Skip to content

Commit

Permalink
rewrite filter logic (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
c-w-feldmann authored Aug 22, 2024
1 parent 81ffb7c commit 9fed198
Showing 1 changed file with 68 additions and 36 deletions.
104 changes: 68 additions & 36 deletions molpipeline/mol2mol/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,34 @@
from molpipeline.utils.value_conversions import count_value_to_tuple


def _within_boundaries(
lower_bound: Optional[float], upper_bound: Optional[float], value: float
) -> bool:
"""Check if a value is within the specified boundaries.
Boundaries given as None are ignored.
Parameters
----------
lower_bound: Optional[float]
Lower boundary.
upper_bound: Optional[float]
Upper boundary.
value: float
Value to check.
Returns
-------
bool
True if the value is within the boundaries, else False.
"""
if lower_bound is not None and value < lower_bound:
return False
if upper_bound is not None and value > upper_bound:
return False
return True


class ElementFilter(_MolToMolPipelineElement):
"""ElementFilter which removes molecules containing chemical elements other than specified.
Expand Down Expand Up @@ -227,9 +255,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol:
match_smarts = smarts_match.GetDescription()
all_matches = value.GetSubstructMatches(Chem.MolFromSmarts(match_smarts))
min_count, max_count = self.patterns[match_smarts]
if (min_count is None or len(all_matches) >= min_count) and (
max_count is None or len(all_matches) <= max_count
):
if _within_boundaries(min_count, max_count, len(all_matches)):
if self.mode == "any":
return (
value
Expand Down Expand Up @@ -290,9 +316,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol:
"""
for pattern, (min_count, max_count) in self.patterns.items():
all_matches = value.GetSubstructMatches(Chem.MolFromSmiles(pattern))
if (min_count is None or len(all_matches) >= min_count) and (
max_count is None or len(all_matches) <= max_count
):
if _within_boundaries(min_count, max_count, len(all_matches)):
if self.mode == "any":
return (
value
Expand Down Expand Up @@ -431,13 +455,19 @@ def set_params(self, **parameters: Any) -> Self:
"""
parameter_copy = dict(parameters)
if "descriptors" in parameter_copy:
self.patterns = parameter_copy.pop("descriptors")
self.descriptors = parameter_copy.pop("descriptors")
super().set_params(**parameter_copy)
return self

def pretransform_single(self, value: RDKitMol) -> OptionalMol:
"""Invalidate or validate molecule based on specified RDKit descriptors.
There are four possible scenarios:
- Mode = "any" & "keep_matches" = True: Needs to match at least one descriptor.
- Mode = "any" & "keep_matches" = False: Must not match any descriptor.
- Mode = "all" & "keep_matches" = True: Needs to match all descriptors.
- Mode = "all" & "keep_matches" = False: Must not match all descriptors.
Parameters
----------
value: RDKitMol
Expand All @@ -450,49 +480,51 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol:
"""
for descriptor, (min_count, max_count) in self.descriptors.items():
descriptor_value = getattr(Descriptors, descriptor)(value)
if (min_count is None or descriptor_value >= min_count) and (
max_count is None or descriptor_value <= max_count
):

if _within_boundaries(min_count, max_count, descriptor_value):
# For "any" mode we can return early if a match is found
if self.mode == "any":
return (
value
if self.keep_matches
else InvalidInstance(
if not self.keep_matches:
value = InvalidInstance(
self.uuid,
f"Molecule contains forbidden descriptor {descriptor}.",
self.name,
)
)
return value
else:
# For "all" mode we can return early if a match is not found
if self.mode == "all":
return (
value
if not self.keep_matches
else InvalidInstance(
if self.keep_matches:
value = InvalidInstance(
self.uuid,
"Molecule does not match all required descriptors.",
f"Molecule does not contain required descriptor {descriptor}.",
self.name,
)
)
return value

# If this point is reached, no or all patterns were found
# If mode is "any", finishing the loop means no match was found
if self.mode == "any":
return (
value
if not self.keep_matches
else InvalidInstance(
if self.keep_matches:
value = InvalidInstance(
self.uuid,
"Molecule does not match any of the DescriptorsFilter descriptors.",
"Molecule does not match any of the required descriptors.",
self.name,
)
)
return (
value
if self.keep_matches
else InvalidInstance(
self.uuid,
"Molecule does not match all of the DescriptorsFilter descriptors.",
self.name,
)
)
# else: No match with forbidden descriptors was found, return original molecule
return value

if self.mode == "all":
if not self.keep_matches:
value = InvalidInstance(
self.uuid,
"Molecule matches all forbidden descriptors.",
self.name,
)
# else: All required descriptors were found, return original molecule
return value

raise ValueError(f"Invalid mode: {self.mode}")


class MixtureFilter(_MolToMolPipelineElement):
Expand Down

0 comments on commit 9fed198

Please sign in to comment.