From 1274a3b02abb97637ba3f07ded5476a566fe6519 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Thu, 22 Aug 2024 14:26:50 +0200 Subject: [PATCH] rewrite filter logic --- molpipeline/mol2mol/filter.py | 104 ++++++++++++++++++++++------------ 1 file changed, 68 insertions(+), 36 deletions(-) diff --git a/molpipeline/mol2mol/filter.py b/molpipeline/mol2mol/filter.py index 46151aef..e7fbf368 100644 --- a/molpipeline/mol2mol/filter.py +++ b/molpipeline/mol2mol/filter.py @@ -27,6 +27,34 @@ from molpipeline.utils.value_conversions import count_value_to_tuple +def _within_boundaries( + lower_bound: Optional[float], upper_bound: Optional[float], value: float +) -> bool: + """Check if a value is within the specified boundaries. + + Boundaries given as None are ignored. + + Parameters + ---------- + lower_bound: Optional[float] + Lower boundary. + upper_bound: Optional[float] + Upper boundary. + value: float + Value to check. + + Returns + ------- + bool + True if the value is within the boundaries, else False. + """ + if lower_bound is not None and value < lower_bound: + return False + if upper_bound is not None and value > upper_bound: + return False + return True + + class ElementFilter(_MolToMolPipelineElement): """ElementFilter which removes molecules containing chemical elements other than specified. @@ -227,9 +255,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: match_smarts = smarts_match.GetDescription() all_matches = value.GetSubstructMatches(Chem.MolFromSmarts(match_smarts)) min_count, max_count = self.patterns[match_smarts] - if (min_count is None or len(all_matches) >= min_count) and ( - max_count is None or len(all_matches) <= max_count - ): + if _within_boundaries(min_count, max_count, len(all_matches)): if self.mode == "any": return ( value @@ -290,9 +316,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: """ for pattern, (min_count, max_count) in self.patterns.items(): all_matches = value.GetSubstructMatches(Chem.MolFromSmiles(pattern)) - if (min_count is None or len(all_matches) >= min_count) and ( - max_count is None or len(all_matches) <= max_count - ): + if _within_boundaries(min_count, max_count, len(all_matches)): if self.mode == "any": return ( value @@ -431,13 +455,19 @@ def set_params(self, **parameters: Any) -> Self: """ parameter_copy = dict(parameters) if "descriptors" in parameter_copy: - self.patterns = parameter_copy.pop("descriptors") + self.descriptors = parameter_copy.pop("descriptors") super().set_params(**parameter_copy) return self def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Invalidate or validate molecule based on specified RDKit descriptors. + There are four possible scenarios: + - Mode = "any" & "keep_matches" = True: Needs to match at least one descriptor. + - Mode = "any" & "keep_matches" = False: Must not match any descriptor. + - Mode = "all" & "keep_matches" = True: Needs to match all descriptors. + - Mode = "all" & "keep_matches" = False: Must not match all descriptors. + Parameters ---------- value: RDKitMol @@ -450,49 +480,51 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: """ for descriptor, (min_count, max_count) in self.descriptors.items(): descriptor_value = getattr(Descriptors, descriptor)(value) - if (min_count is None or descriptor_value >= min_count) and ( - max_count is None or descriptor_value <= max_count - ): + + if _within_boundaries(min_count, max_count, descriptor_value): + # For "any" mode we can return early if a match is found if self.mode == "any": - return ( - value - if self.keep_matches - else InvalidInstance( + if not self.keep_matches: + value = InvalidInstance( self.uuid, f"Molecule contains forbidden descriptor {descriptor}.", self.name, ) - ) + return value else: + # For "all" mode we can return early if a match is not found if self.mode == "all": - return ( - value - if not self.keep_matches - else InvalidInstance( + if self.keep_matches: + value = InvalidInstance( self.uuid, - "Molecule does not match all required descriptors.", + f"Molecule does not contain required descriptor {descriptor}.", self.name, ) - ) + return value + + # If this point is reached, no or all patterns were found + # If mode is "any", finishing the loop means no match was found if self.mode == "any": - return ( - value - if not self.keep_matches - else InvalidInstance( + if self.keep_matches: + value = InvalidInstance( self.uuid, - "Molecule does not match any of the DescriptorsFilter descriptors.", + "Molecule does not match any of the required descriptors.", self.name, ) - ) - return ( - value - if self.keep_matches - else InvalidInstance( - self.uuid, - "Molecule does not match all of the DescriptorsFilter descriptors.", - self.name, - ) - ) + # else: No match with forbidden descriptors was found, return original molecule + return value + + if self.mode == "all": + if not self.keep_matches: + value = InvalidInstance( + self.uuid, + "Molecule matches all forbidden descriptors.", + self.name, + ) + # else: All required descriptors were found, return original molecule + return value + + raise ValueError(f"Invalid mode: {self.mode}") class MixtureFilter(_MolToMolPipelineElement):