From 20407f924b42447022599eaa717fcd143db16944 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Wed, 2 Oct 2024 11:47:07 +0200 Subject: [PATCH 1/3] Add functionality to compare pipelines --- molpipeline/utils/comparison.py | 101 ++++++++++++++++++++++++++++ tests/test_utils/test_comparison.py | 77 +++++++++++++++++++++ 2 files changed, 178 insertions(+) create mode 100644 molpipeline/utils/comparison.py create mode 100644 tests/test_utils/test_comparison.py diff --git a/molpipeline/utils/comparison.py b/molpipeline/utils/comparison.py new file mode 100644 index 00000000..941f5a0e --- /dev/null +++ b/molpipeline/utils/comparison.py @@ -0,0 +1,101 @@ +"""Functions for comparing pipelines.""" + +from typing import Any, TypeVar + +from molpipeline import Pipeline +from molpipeline.utils.json_operations import recursive_to_json + +_T = TypeVar("_T", list[Any], tuple[Any, ...], set[Any], dict[Any, Any], Any) + + +def remove_irrelevant_params(params: _T) -> _T: + """Remove irrelevant parameters from a dictionary. + + Parameters + ---------- + params : TypeVar + Parameters to remove irrelevant parameters from. + + Returns + ------- + TypeVar + Parameters without irrelevant parameters. + """ + if isinstance(params, list): + return [remove_irrelevant_params(val) for val in params] + if isinstance(params, tuple): + return tuple(remove_irrelevant_params(val) for val in params) + if isinstance(params, set): + return {remove_irrelevant_params(val) for val in params} + + irrelevant_params = ["n_jobs", "uuid", "error_filter_id"] + if isinstance(params, dict): + params_new = {} + for key, value in params.items(): + if key.split("__")[-1] in irrelevant_params: + continue + params_new[key] = remove_irrelevant_params(value) + return params_new + return params + + +def compare_recursive( # pylint: disable=too-many-return-statements + value_a: Any, value_b: Any +) -> bool: + """Compare two values recursively. + + Parameters + ---------- + value_a : Any + First value to compare. + value_b : Any + Second value to compare. + + Returns + ------- + bool + True if the values are the same, False otherwise. + """ + if value_a.__class__ != value_b.__class__: + return False + + if isinstance(value_a, dict): + if set(value_a.keys()) != set(value_b.keys()): + return False + for key in value_a: + if not compare_recursive(value_a[key], value_b[key]): + return False + return True + + if isinstance(value_a, (list, tuple)): + if len(value_a) != len(value_b): + return False + for val_a, val_b in zip(value_a, value_b): + if not compare_recursive(val_a, val_b): + return False + return True + return value_a == value_b + + +def check_pipelines_equivalent(pipeline_a: Pipeline, pipeline_b: Pipeline) -> bool: + """Check if two pipelines are the same. + + Parameters + ---------- + pipeline_a : Pipeline + Pipeline to compare. + pipeline_b : Pipeline + Pipeline to compare. + + Returns + ------- + bool + True if the pipelines are the same, False otherwise. + """ + if not isinstance(pipeline_a, Pipeline) or not isinstance(pipeline_b, Pipeline): + raise ValueError("Both inputs should be of type Pipeline.") + pipeline_json_a = recursive_to_json(pipeline_a) + pipeline_json_a = remove_irrelevant_params(pipeline_json_a) + pipeline_json_b = recursive_to_json(pipeline_b) + pipeline_json_b = remove_irrelevant_params(pipeline_json_b) + return compare_recursive(pipeline_json_a, pipeline_json_b) diff --git a/tests/test_utils/test_comparison.py b/tests/test_utils/test_comparison.py new file mode 100644 index 00000000..7545853f --- /dev/null +++ b/tests/test_utils/test_comparison.py @@ -0,0 +1,77 @@ +"""Test the comparison functions.""" + +from unittest import TestCase + +from sklearn.ensemble import RandomForestClassifier + +from molpipeline import Pipeline +from molpipeline.any2mol import SmilesToMol +from molpipeline.error_handling import ErrorFilter, FilterReinserter +from molpipeline.mol2any import ( + MolToConcatenatedVector, + MolToMorganFP, + MolToRDKitPhysChem, +) +from molpipeline.post_prediction import PostPredictionWrapper +from molpipeline.utils.comparison import check_pipelines_equivalent + + +def get_test_pipeline() -> Pipeline: + """Get a test pipeline.""" + + error_filter = ErrorFilter(filter_everything=True) + pipeline = Pipeline( + [ + ("smi2mol", SmilesToMol()), + ( + "mol2fp", + MolToConcatenatedVector( + [ + ("morgan", MolToMorganFP(n_bits=2048)), + ("physchem", MolToRDKitPhysChem()), + ] + ), + ), + ("error_filter", error_filter), + ("rf", RandomForestClassifier()), + ( + "filter_reinserter", + PostPredictionWrapper( + FilterReinserter.from_error_filter(error_filter, None) + ), + ), + ], + n_jobs=1, + ) + + # Set up pipeline + return pipeline + + +class TestComparison(TestCase): + """Test if functional equivalent pipelines are detected as such.""" + + def test_are_equal(self) -> None: + """Test if two equivalent pipelines are detected as such.""" + + pipeline_a = get_test_pipeline() + pipeline_b = get_test_pipeline() + self.assertTrue(check_pipelines_equivalent(pipeline_a, pipeline_b)) + + def test_are_not_equal(self) -> None: + """Test if two different pipelines are detected as such.""" + # Test changed parameters + pipeline_a = get_test_pipeline() + pipeline_b = get_test_pipeline() + pipeline_b.set_params(mol2fp__morgan__n_bits=1024) + self.assertFalse(check_pipelines_equivalent(pipeline_a, pipeline_b)) + + # Test changed steps + pipeline_b = get_test_pipeline() + last_step = pipeline_b.steps[-1] + pipeline_b.steps = pipeline_b.steps[:-1] + self.assertFalse(check_pipelines_equivalent(pipeline_a, pipeline_b)) + + # Test if adding the step back makes the pipelines equivalent + pipeline_b.steps.append(last_step) + self.assertTrue(check_pipelines_equivalent(pipeline_a, pipeline_b)) From 45a287974c05f3e569818b839bc349a4994c5b98 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Wed, 2 Oct 2024 11:49:17 +0200 Subject: [PATCH 2/3] Add return statement --- tests/test_utils/test_comparison.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_utils/test_comparison.py b/tests/test_utils/test_comparison.py index 7545853f..4c1436cd 100644 --- a/tests/test_utils/test_comparison.py +++ b/tests/test_utils/test_comparison.py @@ -17,8 +17,13 @@ def get_test_pipeline() -> Pipeline: - """Get a test pipeline.""" + """Get a test pipeline. + Returns + ------- + Pipeline + Test pipeline. + """ error_filter = ErrorFilter(filter_everything=True) pipeline = Pipeline( [ From a6c175d8f89326a786e79e3d8a080778744614d3 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Mon, 7 Oct 2024 17:58:09 +0200 Subject: [PATCH 3/3] move definition of pipeline to utils function --- tests/test_utils/test_comparison.py | 59 +++-------------------------- tests/utils/default_models.py | 48 +++++++++++++++++++++++ 2 files changed, 54 insertions(+), 53 deletions(-) create mode 100644 tests/utils/default_models.py diff --git a/tests/test_utils/test_comparison.py b/tests/test_utils/test_comparison.py index 4c1436cd..b2b5d707 100644 --- a/tests/test_utils/test_comparison.py +++ b/tests/test_utils/test_comparison.py @@ -2,55 +2,8 @@ from unittest import TestCase -from sklearn.ensemble import RandomForestClassifier - -from molpipeline import Pipeline -from molpipeline.any2mol import SmilesToMol -from molpipeline.error_handling import ErrorFilter, FilterReinserter -from molpipeline.mol2any import ( - MolToConcatenatedVector, - MolToMorganFP, - MolToRDKitPhysChem, -) -from molpipeline.post_prediction import PostPredictionWrapper from molpipeline.utils.comparison import check_pipelines_equivalent - - -def get_test_pipeline() -> Pipeline: - """Get a test pipeline. - - Returns - ------- - Pipeline - Test pipeline. - """ - error_filter = ErrorFilter(filter_everything=True) - pipeline = Pipeline( - [ - ("smi2mol", SmilesToMol()), - ( - "mol2fp", - MolToConcatenatedVector( - [ - ("morgan", MolToMorganFP(n_bits=2048)), - ("physchem", MolToRDKitPhysChem()), - ] - ), - ), - ("error_filter", error_filter), - ("rf", RandomForestClassifier()), - ( - "filter_reinserter", - PostPredictionWrapper( - FilterReinserter.from_error_filter(error_filter, None) - ), - ), - ], - n_jobs=1, - ) - - # Set up pipeline - return pipeline +from tests.utils.default_models import get_morgan_physchem_rf_pipeline class TestComparison(TestCase): @@ -59,20 +12,20 @@ class TestComparison(TestCase): def test_are_equal(self) -> None: """Test if two equivalent pipelines are detected as such.""" - pipeline_a = get_test_pipeline() - pipeline_b = get_test_pipeline() + pipeline_a = get_morgan_physchem_rf_pipeline() + pipeline_b = get_morgan_physchem_rf_pipeline() self.assertTrue(check_pipelines_equivalent(pipeline_a, pipeline_b)) def test_are_not_equal(self) -> None: """Test if two different pipelines are detected as such.""" # Test changed parameters - pipeline_a = get_test_pipeline() - pipeline_b = get_test_pipeline() + pipeline_a = get_morgan_physchem_rf_pipeline() + pipeline_b = get_morgan_physchem_rf_pipeline() pipeline_b.set_params(mol2fp__morgan__n_bits=1024) self.assertFalse(check_pipelines_equivalent(pipeline_a, pipeline_b)) # Test changed steps - pipeline_b = get_test_pipeline() + pipeline_b = get_morgan_physchem_rf_pipeline() last_step = pipeline_b.steps[-1] pipeline_b.steps = pipeline_b.steps[:-1] self.assertFalse(check_pipelines_equivalent(pipeline_a, pipeline_b)) diff --git a/tests/utils/default_models.py b/tests/utils/default_models.py new file mode 100644 index 00000000..a82fe40d --- /dev/null +++ b/tests/utils/default_models.py @@ -0,0 +1,48 @@ +"""This module contains the default models used for testing molpipeline functions and classes.""" + +from sklearn.ensemble import RandomForestClassifier + +from molpipeline import Pipeline +from molpipeline.any2mol import SmilesToMol +from molpipeline.error_handling import ErrorFilter, FilterReinserter +from molpipeline.mol2any import ( + MolToConcatenatedVector, + MolToMorganFP, + MolToRDKitPhysChem, +) +from molpipeline.post_prediction import PostPredictionWrapper + + +def get_morgan_physchem_rf_pipeline() -> Pipeline: + """Get a pipeline combining Morgan fingerprints and physicochemical properties with a RandomForestClassifier. + + Returns + ------- + Pipeline + A pipeline combining Morgan fingerprints and physicochemical properties with a RandomForestClassifier. + """ + error_filter = ErrorFilter(filter_everything=True) + pipeline = Pipeline( + [ + ("smi2mol", SmilesToMol()), + ( + "mol2fp", + MolToConcatenatedVector( + [ + ("morgan", MolToMorganFP(n_bits=2048)), + ("physchem", MolToRDKitPhysChem()), + ] + ), + ), + ("error_filter", error_filter), + ("rf", RandomForestClassifier()), + ( + "filter_reinserter", + PostPredictionWrapper( + FilterReinserter.from_error_filter(error_filter, None) + ), + ), + ], + n_jobs=1, + ) + return pipeline