Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functionality to compare pipelines #95

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions molpipeline/utils/comparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""Functions for comparing pipelines."""

import math
from typing import Any, TypeVar

from molpipeline import Pipeline
from molpipeline.utils.json_operations import recursive_to_json

_T = TypeVar("_T", list[Any], tuple[Any, ...], set[Any], dict[Any, Any], Any)

# Parameter names that are dropped before two parameter structures are compared.
# A dictionary key matches when its last "__"-separated segment is one of these.
_IRRELEVANT_PARAMS = frozenset({"n_jobs", "uuid", "error_filter_id"})


def remove_irrelevant_params(params: _T) -> _T:
    """Remove irrelevant parameters from a (nested) parameter structure.

    Lists, tuples, sets and dictionaries are traversed recursively and rebuilt.
    Dictionary entries are dropped when the last ``__``-separated segment of
    their key is ``n_jobs``, ``uuid`` or ``error_filter_id``. Any other value is
    returned unchanged.

    Parameters
    ----------
    params : list | tuple | set | dict | Any
        Parameter structure to clean. The input itself is not modified.

    Returns
    -------
    list | tuple | set | dict | Any
        New container of the same kind without the irrelevant entries, or the
        input object itself if it is not one of the handled container types.
    """
    if isinstance(params, list):
        return [remove_irrelevant_params(val) for val in params]
    if isinstance(params, tuple):
        return tuple(remove_irrelevant_params(val) for val in params)
    if isinstance(params, set):
        return {remove_irrelevant_params(val) for val in params}

    if isinstance(params, dict):
        params_new = {}
        for key, value in params.items():
            # Only string keys can follow the "step__param" naming scheme.
            # Non-string keys (e.g. integer class labels in a nested mapping)
            # are always kept instead of failing on ``key.split``.
            if isinstance(key, str) and key.split("__")[-1] in _IRRELEVANT_PARAMS:
                continue
            params_new[key] = remove_irrelevant_params(value)
        return params_new
    return params


def compare_recursive(  # pylint: disable=too-many-return-statements
    value_a: Any, value_b: Any
) -> bool:
    """Compare two values recursively.

    Values of different classes are never equal. Dictionaries are equal if they
    have the same keys and all corresponding values are recursively equal.
    Lists and tuples are equal if they have the same length and all elements
    are recursively equal, position by position. Two NaN floats are treated as
    equal. Everything else is compared with ``==``.

    Parameters
    ----------
    value_a : Any
        First value to compare.
    value_b : Any
        Second value to compare.

    Returns
    -------
    bool
        True if the values are the same, False otherwise.
    """
    if value_a.__class__ != value_b.__class__:
        return False

    if isinstance(value_a, dict):
        if set(value_a.keys()) != set(value_b.keys()):
            return False
        return all(compare_recursive(value_a[key], value_b[key]) for key in value_a)

    if isinstance(value_a, (list, tuple)):
        if len(value_a) != len(value_b):
            return False
        return all(
            compare_recursive(val_a, val_b) for val_a, val_b in zip(value_a, value_b)
        )

    # NaN never compares equal to itself with ``==``. Without this check a
    # structure holding a NaN parameter would not even be equal to its own copy.
    # value_b has the same class as value_a here, so it is a float as well.
    if isinstance(value_a, float) and math.isnan(value_a) and math.isnan(value_b):
        return True
    return value_a == value_b


def check_pipelines_equivalent(pipeline_a: Pipeline, pipeline_b: Pipeline) -> bool:
    """Determine whether two pipelines are equivalent.

    Each pipeline is converted with ``recursive_to_json``, cleaned with
    ``remove_irrelevant_params`` and the two results are then compared with
    ``compare_recursive``.

    Parameters
    ----------
    pipeline_a : Pipeline
        First pipeline of the comparison.
    pipeline_b : Pipeline
        Second pipeline of the comparison.

    Returns
    -------
    bool
        True if both cleaned representations match, False otherwise.

    Raises
    ------
    ValueError
        If at least one of the inputs is not a Pipeline instance.
    """
    if not (isinstance(pipeline_a, Pipeline) and isinstance(pipeline_b, Pipeline)):
        raise ValueError("Both inputs should be of type Pipeline.")

    # Normalise both pipelines the same way, first a and then b.
    cleaned_a, cleaned_b = (
        remove_irrelevant_params(recursive_to_json(pipeline))
        for pipeline in (pipeline_a, pipeline_b)
    )
    return compare_recursive(cleaned_a, cleaned_b)
82 changes: 82 additions & 0 deletions tests/test_utils/test_comparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Test the comparison functions."""

from unittest import TestCase

from sklearn.ensemble import RandomForestClassifier

from molpipeline import Pipeline
from molpipeline.any2mol import SmilesToMol
from molpipeline.error_handling import ErrorFilter, FilterReinserter
from molpipeline.mol2any import (
MolToConcatenatedVector,
MolToMorganFP,
MolToRDKitPhysChem,
)
from molpipeline.post_prediction import PostPredictionWrapper
from molpipeline.utils.comparison import check_pipelines_equivalent


def get_test_pipeline() -> Pipeline:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we already have such a function somewhere in the library? If not, maybe move it into the test utils.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved it to tests/utils/default_models.py.
We have done something similar for Chemprop (see test_extras/test_chemprop/chemprop_test_utils/default_models.py)

"""Get a test pipeline.

Returns
-------
Pipeline
Test pipeline.
"""
error_filter = ErrorFilter(filter_everything=True)
pipeline = Pipeline(
[
("smi2mol", SmilesToMol()),
(
"mol2fp",
MolToConcatenatedVector(
[
("morgan", MolToMorganFP(n_bits=2048)),
("physchem", MolToRDKitPhysChem()),
]
),
),
("error_filter", error_filter),
("rf", RandomForestClassifier()),
(
"filter_reinserter",
PostPredictionWrapper(
FilterReinserter.from_error_filter(error_filter, None)
),
),
],
n_jobs=1,
)

# Set up pipeline
return pipeline


class TestComparison(TestCase):
    """Check that pipeline equivalence is detected correctly."""

    def test_are_equal(self) -> None:
        """Two independently built test pipelines must be equivalent."""
        first = get_test_pipeline()
        second = get_test_pipeline()
        self.assertTrue(check_pipelines_equivalent(first, second))

    def test_are_not_equal(self) -> None:
        """Changed parameters or missing steps must break equivalence."""
        reference = get_test_pipeline()

        # A pipeline with a different parameter value is not equivalent.
        modified = get_test_pipeline()
        modified.set_params(mol2fp__morgan__n_bits=1024)
        self.assertFalse(check_pipelines_equivalent(reference, modified))

        # A pipeline lacking the final step is not equivalent.
        modified = get_test_pipeline()
        *remaining_steps, removed_step = modified.steps
        modified.steps = remaining_steps
        self.assertFalse(check_pipelines_equivalent(reference, modified))

        # Restoring the removed step makes the pipelines equivalent again.
        modified.steps.append(removed_step)
        self.assertTrue(check_pipelines_equivalent(reference, modified))
Loading