From 0a02cdab7522fe06aabf29ceebea02e5cff0f65f Mon Sep 17 00:00:00 2001 From: Jochen Sieg Date: Thu, 6 Jun 2024 17:24:35 +0200 Subject: [PATCH] utils/mol2concatinated_vector: changes for xai - Add helper class SubpipelineExtractor to get certain parts of an existing Pipeline. - Add property to mol2concatinated_vector to extract total number of features. --- .../mol2any/mol2concatinated_vector.py | 15 + molpipeline/utils/subpipeline.py | 384 ++++++++++++++++++ .../test_mol2any/test_mol2concatenated.py | 48 ++- tests/test_utils/test_subpipeline.py | 336 +++++++++++++++ 4 files changed, 777 insertions(+), 6 deletions(-) create mode 100644 molpipeline/utils/subpipeline.py create mode 100644 tests/test_utils/test_subpipeline.py diff --git a/molpipeline/mol2any/mol2concatinated_vector.py b/molpipeline/mol2any/mol2concatinated_vector.py index bf85256e..490baadc 100644 --- a/molpipeline/mol2any/mol2concatinated_vector.py +++ b/molpipeline/mol2any/mol2concatinated_vector.py @@ -72,6 +72,21 @@ def element_list(self) -> list[tuple[str, MolToAnyPipelineElement]]: """Get pipeline elements.""" return self._element_list + @property + def n_features(self) -> int: + """Calculates and returns the number of features.""" + feature_count = 0 + for _, element in self._element_list: + if hasattr(element, "n_features"): + feature_count += element.n_features + elif hasattr(element, "n_bits"): + feature_count += element.n_bits + else: + raise AssertionError( + f"Element {element} does not have n_features or n_bits." + ) + return feature_count + def get_params(self, deep: bool = True) -> dict[str, Any]: """Return all parameters defining the object. diff --git a/molpipeline/utils/subpipeline.py b/molpipeline/utils/subpipeline.py new file mode 100644 index 00000000..40cb27ef --- /dev/null +++ b/molpipeline/utils/subpipeline.py @@ -0,0 +1,384 @@ +"""Helper functions to extract subpipelines from a pipeline.""" + +from __future__ import annotations + +from typing import Any, Callable + +from sklearn.base import BaseEstimator + +from molpipeline import FilterReinserter, Pipeline, PostPredictionWrapper +from molpipeline.abstract_pipeline_elements.core import ( + AnyToMolPipelineElement, + MolToAnyPipelineElement, +) + + +def _get_molecule_reading_position_from_pipeline(pipeline: Pipeline) -> int | None: + """Heuristic to select the position of the central molecule reading element in a pipeline. + + This function searches the last AnyToMolPipelineElement in the pipeline. We select the last + AnyToMolPipelineElement in the pipeline because we have some standardization pipelines + that write the molecule a smiles and read them back in to ensure they are readable. + + Parameters + ---------- + pipeline: Pipeline + The pipeline to search for the molecule reading element. + + Returns + ------- + int | None + The position of the molecule reading element in the pipeline. + """ + for i, step in enumerate(reversed(pipeline.steps)): + if isinstance(step[1], AnyToMolPipelineElement): + return len(pipeline.steps) - i - 1 + return None + + +def _get_model_element_position_from_pipeline(pipeline: Pipeline) -> int | None: + """Heuristic to select the position of the machine learning estimator model in a pipeline. + + Parameters + ---------- + pipeline: Pipeline + The pipeline to search for the model element. + + Returns + ------- + int | None + The position of the model element in the pipeline or None if no model element is found. + """ + for i, step in enumerate(reversed(pipeline.steps)): + if isinstance(step[1], BaseEstimator): + if isinstance(step[1], PostPredictionWrapper): + # skip PostPredictionWrappers. TODO is this reasonable? + continue + return len(pipeline.steps) - i - 1 + return None + + +def _get_featurization_element_position_from_pipeline(pipeline: Pipeline) -> int | None: + """Heuristic to select the position of the featurization element in a pipeline. + + Parameters + ---------- + pipeline: Pipeline + The pipeline to search for the featurization element. + + Returns + ------- + int | None + The position of the featurization element in the pipeline or None if no featurization element is found. + + """ + for i, step in enumerate(reversed(pipeline.steps)): + if isinstance(step[1], MolToAnyPipelineElement): + return len(pipeline.steps) - i - 1 + return None + + +class SubpipelineExtractor: + """A helper class to extract parts of a pipeline.""" + + def __init__(self, pipeline: Pipeline) -> None: + """Initialize the SubpipelineExtractor. + + Parameters + ---------- + pipeline : Pipeline + The pipeline to extract subpipelines from. + """ + self.pipeline = pipeline + + def _get_index_of_element_by_id(self, element: Any) -> int | None: + """Get the index of an element by id (the pointer or memory address). + + Parameters + ---------- + element : Any + The element to extract. + + Returns + ------- + int | None + The index of the element or None if the element was not found. + """ + for i, (_, pipeline_element) in enumerate(self.pipeline.steps): + if id(pipeline_element) == id(element): + return i + return None + + def _get_index_of_element_by_name(self, element_name: str) -> int | None: + """Get the index of an element by name. + + Parameters + ---------- + element_name : str + The name of the element to extract. + + Returns + ------- + int | None + The index of the element or None if the element was not found. + """ + for i, (name, _) in enumerate(self.pipeline.steps): + if name == element_name: + return i + return None + + def _extract_single_element_index( + self, + element_name: str | None, + get_index_function: Callable[[Pipeline], int | None], + ) -> int | None: + """Extract the index of a single element from the pipeline. + + Parameters + ---------- + element_name : str | None + The name of the element to extract. + get_index_function : Callable[[Pipeline], int | None] + A function that returns the index of the element to extract. + + Returns + ------- + Any | None + The index of the extracted element or None if the element was not found. + """ + if element_name is not None: + return self._get_index_of_element_by_name(element_name) + return get_index_function(self.pipeline) + + def _extract_single_element( + self, + element_name: str | None, + get_index_function: Callable[[Pipeline], int | None], + ) -> Any | None: + """Extract a single element from the pipeline. + + Parameters + ---------- + element_name : str | None + The name of the element to extract. + get_index_function : Callable[[Pipeline], int | None] + A function that returns the index of the element to extract. + + Returns + ------- + Any | None + The extracted element or None if the element was not found. + """ + if element_name is not None: + # if a name is provided, access the element by name + return self.pipeline.named_steps[element_name] + element_index = self._extract_single_element_index(None, get_index_function) + if element_index is None: + return None + return self.pipeline.steps[element_index][1] + + def get_molecule_reader_element( + self, element_name: str | None = None + ) -> AnyToMolPipelineElement | None: + """Get the molecule reader element from the pipeline, e.g. a SmilesToMol element. + + Parameters + ---------- + element_name : str | None + The name of the element to extract. + + Returns + ------- + AnyToMolPipelineElement | None + The extracted molecule reader element or None if the element was not found. + """ + return self._extract_single_element( + element_name, + _get_molecule_reading_position_from_pipeline, + ) + + def get_featurization_element( + self, element_name: str | None = None + ) -> BaseEstimator | None: + """Get the featurization element from the pipeline, e.g., a MolToMorganFP element. + + Parameters + ---------- + element_name : str | None + The name of the element to extract. + + Returns + ------- + BaseEstimator | None + The extracted featurization element or None if the element was not found. + """ + return self._extract_single_element( + element_name, _get_featurization_element_position_from_pipeline + ) + + def get_model_element( + self, element_name: str | None = None + ) -> BaseEstimator | None: + """Get the machine learning model element from the pipeline, e.g. a RandomForestClassifier. + + Parameters + ---------- + element_name : str | None + The name of the element to extract. + + Returns + ------- + BaseEstimator | None + The extracted model element or None if the element was not found. + """ + return self._extract_single_element( + element_name, _get_model_element_position_from_pipeline + ) + + def _get_subpipline_from_start( + self, + element_name: str | None, + start_get_index_function: Callable[[Pipeline], int | None], + ) -> Pipeline | None: + """Get a subpipeline up to a specific element starting from the first element of the original pipeline. + + Parameters + ---------- + element_name : str | None + The name of the element to extract. + start_get_index_function : Callable[[Pipeline], int | None] + A function that returns the index of the subpipline's last element. + + Returns + ------- + Pipeline | None + The extracted subpipeline or None if the corresponding last element was not found. + """ + element_index = self._extract_single_element_index( + element_name, start_get_index_function + ) + if element_index is None: + return None + return Pipeline(steps=self.pipeline.steps[: element_index + 1]) + + def get_molecule_reader_subpipeline( + self, element_name: str | None = None + ) -> Pipeline | None: + """Get a subpipeline up to the molecule reading element. + + Note that standardization steps, like salt removal, are not guaranteed to be included. + + Parameters + ---------- + element_name : str | None + The name of the last element in the subpipeline to extract. + + Returns + ------- + Pipeline | None + The extracted subpipeline or None if the corresponding last element was not found. + """ + return self._get_subpipline_from_start( + element_name, _get_molecule_reading_position_from_pipeline + ) + + def get_featurization_subpipeline( + self, element_name: str | None = None + ) -> Pipeline | None: + """Get a subpipeline up to the featurization element. + + Parameters + ---------- + element_name : str | None + The name of the last element in the subpipeline to extract. + + Returns + ------- + Pipeline | None + The extracted subpipeline or None if the corresponding last element was not found. + """ + return self._get_subpipline_from_start( + element_name, _get_featurization_element_position_from_pipeline + ) + + def get_model_subpipeline(self, element_name: str | None = None) -> Pipeline | None: + """Get a subpipeline up to the machine learning model element. + + Parameters + ---------- + element_name : str | None + The name of the last element in the subpipeline to extract. + + Returns + ------- + Pipeline | None + The extracted subpipeline or None if the corresponding last element was not found. + """ + return self._get_subpipline_from_start( + element_name, _get_model_element_position_from_pipeline + ) + + def get_subpipeline( + self, + first_element: Any, + second_element: Any, + first_offset: int = 0, + second_offset: int = 0, + ) -> Pipeline | None: + """Get a subpipeline between two elements. + + This function only checks the names of the elements. + If the elements are not found or the second element is before the first element, a ValueError is raised. + + Parameters + ---------- + first_element : Any + The first element of the subpipeline. + second_element : Any + The second element of the subpipeline. + first_offset : int + The offset to apply to the first element. + second_offset : int + The offset to apply to the second element. + + Returns + ------- + Pipeline | None + The extracted subpipeline or None if the elements were not found. + """ + first_element_index = self._get_index_of_element_by_id(first_element) + if first_element_index is None: + raise ValueError(f"Element {first_element} not found in pipeline.") + second_element_index = self._get_index_of_element_by_id(second_element) + if second_element_index is None: + raise ValueError(f"Element {second_element} not found in pipeline.") + + # apply user-defined offsets + first_element_index += first_offset + second_element_index += second_offset + + if second_element_index < first_element_index: + raise ValueError( + f"Element {second_element} must be after element {first_element}." + ) + return Pipeline( + steps=self.pipeline.steps[first_element_index : second_element_index + 1] + ) + + def get_all_filter_reinserter_fill_values(self) -> list[Any]: + """Get all fill values for FilterReinserter elements in the pipeline. + + Returns + ------- + list[Any] + The fill values for all FilterReinserter elements in the pipeline. + """ + fill_values = set() + for _, step in self.pipeline.steps: + if isinstance(step, FilterReinserter): + fill_values.add(step.fill_value) + if isinstance(step, PostPredictionWrapper) and isinstance( + step.wrapped_estimator, FilterReinserter + ): + fill_values.add(step.wrapped_estimator.fill_value) + return list(fill_values) diff --git a/tests/test_elements/test_mol2any/test_mol2concatenated.py b/tests/test_elements/test_mol2any/test_mol2concatenated.py index 420794ba..5bb57742 100644 --- a/tests/test_elements/test_mol2any/test_mol2concatenated.py +++ b/tests/test_elements/test_mol2any/test_mol2concatenated.py @@ -14,6 +14,7 @@ from molpipeline.mol2any import ( MolToConcatenatedVector, MolToMorganFP, + MolToNetCharge, MolToRDKitPhysChem, ) from tests.utils.fingerprints import fingerprints_to_numpy @@ -23,12 +24,7 @@ class TestConcatenatedFingerprint(unittest.TestCase): """Unittest for MolToConcatenatedVector, which calculates concatenated fingerprints.""" def test_generation(self) -> None: - """Test if the feature concatenation works as expected. - - Returns - ------- - None - """ + """Test if the feature concatenation works as expected.""" fingerprint_morgan_output_types: tuple[Any, ...] = get_args( Literal[ "sparse", @@ -95,6 +91,46 @@ def test_generation(self) -> None: self.assertTrue(np.allclose(output, output2)) self.assertTrue(np.allclose(output, output3)) + def test_n_features(self) -> None: + """Test getting the number of features in the concatenated vector.""" + + physchem_elem = ( + "RDKitPhysChem", + MolToRDKitPhysChem(), + ) + morgan_elem = ( + "MorganFP", + MolToMorganFP(n_bits=16), + ) + net_charge_elem = ("NetCharge", MolToNetCharge()) + + self.assertEqual( + MolToConcatenatedVector([physchem_elem]).n_features, + physchem_elem[1].n_features, + ) + self.assertEqual( + MolToConcatenatedVector([morgan_elem]).n_features, + 16, + ) + self.assertEqual( + MolToConcatenatedVector([net_charge_elem]).n_features, + net_charge_elem[1].n_features, + ) + self.assertEqual( + MolToConcatenatedVector([physchem_elem, morgan_elem]).n_features, + physchem_elem[1].n_features + 16, + ) + self.assertEqual( + MolToConcatenatedVector([net_charge_elem, morgan_elem]).n_features, + net_charge_elem[1].n_features + 16, + ) + self.assertEqual( + MolToConcatenatedVector( + [net_charge_elem, morgan_elem, physchem_elem] + ).n_features, + net_charge_elem[1].n_features + 16 + physchem_elem[1].n_features, + ) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_utils/test_subpipeline.py b/tests/test_utils/test_subpipeline.py new file mode 100644 index 00000000..672f203e --- /dev/null +++ b/tests/test_utils/test_subpipeline.py @@ -0,0 +1,336 @@ +"""Test SubpipelineExtractor.""" + +import unittest + +import numpy as np +from sklearn.ensemble import RandomForestClassifier + +from molpipeline import ErrorFilter, FilterReinserter, Pipeline, PostPredictionWrapper +from molpipeline.any2mol import SmilesToMol +from molpipeline.mol2any import MolToMorganFP, MolToSmiles +from molpipeline.mol2mol import SaltRemover +from molpipeline.utils.subpipeline import SubpipelineExtractor + + +class TestSubpipelineExtractor(unittest.TestCase): + """Test SubpipelineExtractor.""" + + def test_get_molecule_reader_element(self) -> None: + """Test extracting molecule reader element from pipelines.""" + + # test basic example + pipeline = Pipeline( + [ + ("smi2mol", SmilesToMol()), + ("morgan", MolToMorganFP(radius=1, n_bits=64)), + ("model", RandomForestClassifier()), + ] + ) + extractor = SubpipelineExtractor(pipeline) + self.assertIs(extractor.get_molecule_reader_element(), pipeline.steps[0][1]) + + # test with multiple molecule readers + pipeline = Pipeline( + [ + ("smi2mol", SmilesToMol()), + ("mol2smi", MolToSmiles()), + ("smi2mol2", SmilesToMol()), + ("morgan", MolToMorganFP(radius=1, n_bits=64)), + ("model", RandomForestClassifier()), + ] + ) + extractor = SubpipelineExtractor(pipeline) + self.assertIs(extractor.get_molecule_reader_element(), pipeline.steps[2][1]) + + def test_get_featurization_element(self) -> None: + """Test extracting featurization element from pipelines.""" + + # test basic example + pipeline = Pipeline( + [ + ("smi2mol", SmilesToMol()), + ("morgan", MolToMorganFP(radius=1, n_bits=64)), + ("model", RandomForestClassifier()), + ] + ) + extractor = SubpipelineExtractor(pipeline) + self.assertIs(extractor.get_featurization_element(), pipeline.steps[1][1]) + + # test with PostPredictionWrapper + error_filter = ErrorFilter() + error_reinserter = PostPredictionWrapper( + FilterReinserter.from_error_filter(error_filter, None) + ) + pipeline = Pipeline( + [ + ("smi2mol", SmilesToMol()), + ("error_filter", error_filter), + ("morgan", MolToMorganFP(radius=1, n_bits=64)), + ("model", RandomForestClassifier()), + ( + "error_reinserter", + error_reinserter, + ), + ] + ) + extractor = SubpipelineExtractor(pipeline) + self.assertIs(extractor.get_featurization_element(), pipeline.steps[2][1]) + + def test_get_model_element(self) -> None: + """Test extracting model element from pipeline.""" + + # test basic example + pipeline = Pipeline( + [ + ("smi2mol", SmilesToMol()), + ("morgan", MolToMorganFP(radius=1, n_bits=64)), + ("model", RandomForestClassifier()), + ] + ) + extractor = SubpipelineExtractor(pipeline) + self.assertIs(extractor.get_model_element(), pipeline.steps[2][1]) + + # test with PostPredictionWrapper + error_filter = ErrorFilter() + error_reinserter = PostPredictionWrapper( + FilterReinserter.from_error_filter(error_filter, None) + ) + pipeline = Pipeline( + [ + ("smi2mol", SmilesToMol()), + ("error_filter", error_filter), + ("morgan", MolToMorganFP(radius=1, n_bits=64)), + ("model", RandomForestClassifier()), + ( + "error_reinserter", + error_reinserter, + ), + ] + ) + extractor = SubpipelineExtractor(pipeline) + self.assertIs(extractor.get_model_element(), pipeline.steps[3][1]) + + def test_get_molecule_reader_subpipeline(self) -> None: + """Test extracting subpipeline up to the molecule reader element from pipelines.""" + + # test basic example + pipeline = Pipeline( + [ + ("smi2mol", SmilesToMol()), + ("morgan", MolToMorganFP(radius=1, n_bits=64)), + ("model", RandomForestClassifier()), + ] + ) + extractor = SubpipelineExtractor(pipeline) + subpipeline = extractor.get_molecule_reader_subpipeline() + self.assertIsInstance(subpipeline, Pipeline) + self.assertEqual(len(subpipeline.steps), 1) # type: ignore[union-attr] + self.assertIs(subpipeline.steps[0], pipeline.steps[0]) # type: ignore[union-attr] + + # test with multiple molecule readers + pipeline = Pipeline( + [ + ("smi2mol", SmilesToMol()), + ("mol2smi", MolToSmiles()), + ("smi2mol2", SmilesToMol()), + ("morgan", MolToMorganFP(radius=1, n_bits=64)), + ("model", RandomForestClassifier()), + ] + ) + extractor = SubpipelineExtractor(pipeline) + subpipeline = extractor.get_molecule_reader_subpipeline() + self.assertIsInstance(subpipeline, Pipeline) + self.assertEqual(len(subpipeline.steps), 3) # type: ignore[union-attr] + for i, subpipe_step in enumerate(subpipeline.steps): # type: ignore[union-attr] + self.assertIs(subpipe_step, pipeline.steps[i]) + + def test_get_featurization_subpipeline(self) -> None: + """Test extracting subpipeline up to the featurization element from pipelines.""" + + # test basic example + pipeline = Pipeline( + [ + ("smi2mol", SmilesToMol()), + ("morgan", MolToMorganFP(radius=1, n_bits=64)), + ("model", RandomForestClassifier()), + ] + ) + extractor = SubpipelineExtractor(pipeline) + subpipeline = extractor.get_featurization_subpipeline() + self.assertIsInstance(subpipeline, Pipeline) + self.assertEqual(len(subpipeline.steps), 2) # type: ignore[union-attr] + for i, subpipe_step in enumerate(subpipeline.steps): # type: ignore[union-attr] + self.assertIs(subpipe_step, pipeline.steps[i]) + + # test with PostPredictionWrapper + error_filter = ErrorFilter() + error_reinserter = PostPredictionWrapper( + FilterReinserter.from_error_filter(error_filter, None) + ) + pipeline = Pipeline( + [ + ("smi2mol", SmilesToMol()), + ("error_filter", error_filter), + ("morgan", MolToMorganFP(radius=1, n_bits=64)), + ("model", RandomForestClassifier()), + ( + "error_reinserter", + error_reinserter, + ), + ] + ) + extractor = SubpipelineExtractor(pipeline) + subpipeline = extractor.get_featurization_subpipeline() + self.assertIsInstance(subpipeline, Pipeline) + self.assertEqual(len(subpipeline.steps), 3) # type: ignore[union-attr] + for i, subpipe_step in enumerate(subpipeline.steps): # type: ignore[union-attr] + self.assertIs(subpipe_step, pipeline.steps[i]) + + def test_get_model_subpipeline(self) -> None: + """Test extracting subpipeline up to the model element from pipelines.""" + + # test basic example + pipeline = Pipeline( + [ + ("smi2mol", SmilesToMol()), + ("morgan", MolToMorganFP(radius=1, n_bits=64)), + ("model", RandomForestClassifier()), + ] + ) + extractor = SubpipelineExtractor(pipeline) + subpipeline = extractor.get_model_subpipeline() + self.assertIsInstance(subpipeline, Pipeline) + self.assertEqual(len(subpipeline.steps), 3) # type: ignore[union-attr] + for i, subpipe_step in enumerate(subpipeline.steps): # type: ignore[union-attr] + self.assertIs(subpipe_step, pipeline.steps[i]) + + # test with PostPredictionWrapper + error_filter = ErrorFilter() + error_reinserter = PostPredictionWrapper( + FilterReinserter.from_error_filter(error_filter, None) + ) + pipeline = Pipeline( + [ + ("smi2mol", SmilesToMol()), + ("error_filter", error_filter), + ("morgan", MolToMorganFP(radius=1, n_bits=64)), + ("model", RandomForestClassifier()), + ( + "error_reinserter", + error_reinserter, + ), + ] + ) + extractor = SubpipelineExtractor(pipeline) + subpipeline = extractor.get_model_subpipeline() + self.assertIsInstance(subpipeline, Pipeline) + self.assertEqual(len(subpipeline.steps), 4) # type: ignore[union-attr] + for i, subpipe_step in enumerate(subpipeline.steps): # type: ignore[union-attr] + self.assertIs(subpipe_step, pipeline.steps[i]) + + def test_get_subpipeline(self) -> None: + """Test extracting subpipeline as a certain interval from the original pipeline.""" + + pipeline = Pipeline( + [ + ("smi2mol", SmilesToMol()), + ("salt_remover", SaltRemover()), + ("morgan", MolToMorganFP(radius=1, n_bits=64)), + ("model", RandomForestClassifier()), + ] + ) + extractor = SubpipelineExtractor(pipeline) + reader_element = extractor.get_molecule_reader_element("smi2mol") + self.assertIs(reader_element, pipeline.steps[0][1]) + feature_element = extractor.get_featurization_element("morgan") + self.assertIs(feature_element, pipeline.steps[2][1]) + model_element = extractor.get_model_element("model") + self.assertIs(model_element, pipeline.steps[3][1]) + + # test smi2mol to morgan + subpipeline_reader_feature = extractor.get_subpipeline( + reader_element, feature_element + ) + self.assertIsInstance(subpipeline_reader_feature, Pipeline) + self.assertEqual(len(subpipeline_reader_feature.steps), 3) # type: ignore[union-attr] + self.assertIs(subpipeline_reader_feature.steps[0], pipeline.steps[0]) # type: ignore[union-attr] + self.assertIs(subpipeline_reader_feature.steps[1], pipeline.steps[1]) # type: ignore[union-attr] + self.assertIs(subpipeline_reader_feature.steps[2], pipeline.steps[2]) # type: ignore[union-attr] + + # test smi2mol to model + subpipeline_reader_model = extractor.get_subpipeline( + reader_element, model_element + ) + self.assertIsInstance(subpipeline_reader_model, Pipeline) + self.assertEqual(len(subpipeline_reader_model.steps), 4) # type: ignore[union-attr] + self.assertIs(subpipeline_reader_model.steps[0], pipeline.steps[0]) # type: ignore[union-attr] + self.assertIs(subpipeline_reader_model.steps[1], pipeline.steps[1]) # type: ignore[union-attr] + self.assertIs(subpipeline_reader_model.steps[2], pipeline.steps[2]) # type: ignore[union-attr] + self.assertIs(subpipeline_reader_model.steps[3], pipeline.steps[3]) # type: ignore[union-attr] + + # test morgan to model + subpipeline_feature_model = extractor.get_subpipeline( + feature_element, model_element + ) + self.assertIsInstance(subpipeline_feature_model, Pipeline) + self.assertEqual(len(subpipeline_feature_model.steps), 2) # type: ignore[union-attr] + self.assertIs(subpipeline_feature_model.steps[0], pipeline.steps[2]) # type: ignore[union-attr] + self.assertIs(subpipeline_feature_model.steps[1], pipeline.steps[3]) # type: ignore[union-attr] + + # test morgan to morgan + subpipeline_feature_feature = extractor.get_subpipeline( + feature_element, feature_element + ) + self.assertIsInstance(subpipeline_feature_feature, Pipeline) + self.assertEqual(len(subpipeline_feature_feature.steps), 1) # type: ignore[union-attr] + self.assertIs(subpipeline_feature_feature.steps[0], pipeline.steps[2]) # type: ignore[union-attr] + + # test the first element comes after the second element + self.assertRaises( + ValueError, + extractor.get_subpipeline, + feature_element, + reader_element, + ) + + element_not_in_pipeline = SmilesToMol() + + # test element not in pipeline raises an exception + self.assertRaises( + ValueError, + extractor.get_subpipeline, + element_not_in_pipeline, + feature_element, + ) + self.assertRaises( + ValueError, + extractor.get_subpipeline, + reader_element, + element_not_in_pipeline, + ) + + def test_get_all_filter_reinserter_fill_values(self) -> None: + """Test extracting all FilterReinserter fill values from pipelines.""" + + test_fill_values = [None, np.nan] + + for test_fill_value in test_fill_values: + error_filter = ErrorFilter() + error_reinserter = PostPredictionWrapper( + FilterReinserter.from_error_filter(error_filter, test_fill_value) + ) + pipeline = Pipeline( + [ + ("smi2mol", SmilesToMol()), + ("error_filter", error_filter), + ("morgan", MolToMorganFP(radius=1, n_bits=64)), + ("model", RandomForestClassifier()), + ( + "error_reinserter", + error_reinserter, + ), + ] + ) + extractor = SubpipelineExtractor(pipeline) + fill_values = extractor.get_all_filter_reinserter_fill_values() + self.assertEqual(fill_values, [test_fill_value])