Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rewrite filter logic #71

Merged
merged 1 commit into from
Aug 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 68 additions & 36 deletions molpipeline/mol2mol/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,34 @@
from molpipeline.utils.value_conversions import count_value_to_tuple


def _within_boundaries(
lower_bound: Optional[float], upper_bound: Optional[float], value: float
) -> bool:
"""Check if a value is within the specified boundaries.

Boundaries given as None are ignored.

Parameters
----------
lower_bound: Optional[float]
Lower boundary.
upper_bound: Optional[float]
Upper boundary.
value: float
Value to check.

Returns
-------
bool
True if the value is within the boundaries, else False.
"""
if lower_bound is not None and value < lower_bound:
return False
if upper_bound is not None and value > upper_bound:
return False
return True


class ElementFilter(_MolToMolPipelineElement):
"""ElementFilter which removes molecules containing chemical elements other than specified.

Expand Down Expand Up @@ -227,9 +255,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol:
match_smarts = smarts_match.GetDescription()
all_matches = value.GetSubstructMatches(Chem.MolFromSmarts(match_smarts))
min_count, max_count = self.patterns[match_smarts]
if (min_count is None or len(all_matches) >= min_count) and (
max_count is None or len(all_matches) <= max_count
):
if _within_boundaries(min_count, max_count, len(all_matches)):
if self.mode == "any":
return (
value
Expand Down Expand Up @@ -290,9 +316,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol:
"""
for pattern, (min_count, max_count) in self.patterns.items():
all_matches = value.GetSubstructMatches(Chem.MolFromSmiles(pattern))
if (min_count is None or len(all_matches) >= min_count) and (
max_count is None or len(all_matches) <= max_count
):
if _within_boundaries(min_count, max_count, len(all_matches)):
if self.mode == "any":
return (
value
Expand Down Expand Up @@ -431,13 +455,19 @@ def set_params(self, **parameters: Any) -> Self:
"""
parameter_copy = dict(parameters)
if "descriptors" in parameter_copy:
self.patterns = parameter_copy.pop("descriptors")
self.descriptors = parameter_copy.pop("descriptors")
super().set_params(**parameter_copy)
return self

def pretransform_single(self, value: RDKitMol) -> OptionalMol:
"""Invalidate or validate molecule based on specified RDKit descriptors.

There are four possible scenarios:
- Mode = "any" & "keep_matches" = True: Needs to match at least one descriptor.
- Mode = "any" & "keep_matches" = False: Must not match any descriptor.
- Mode = "all" & "keep_matches" = True: Needs to match all descriptors.
- Mode = "all" & "keep_matches" = False: Must not match all descriptors.

Parameters
----------
value: RDKitMol
Expand All @@ -450,49 +480,51 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol:
"""
for descriptor, (min_count, max_count) in self.descriptors.items():
descriptor_value = getattr(Descriptors, descriptor)(value)
if (min_count is None or descriptor_value >= min_count) and (
max_count is None or descriptor_value <= max_count
):

if _within_boundaries(min_count, max_count, descriptor_value):
# For "any" mode we can return early if a match is found
if self.mode == "any":
return (
value
if self.keep_matches
else InvalidInstance(
if not self.keep_matches:
value = InvalidInstance(
self.uuid,
f"Molecule contains forbidden descriptor {descriptor}.",
self.name,
)
)
return value
else:
# For "all" mode we can return early if a match is not found
if self.mode == "all":
return (
value
if not self.keep_matches
else InvalidInstance(
if self.keep_matches:
value = InvalidInstance(
self.uuid,
"Molecule does not match all required descriptors.",
f"Molecule does not contain required descriptor {descriptor}.",
self.name,
)
)
return value

# If this point is reached, no or all patterns were found
# If mode is "any", finishing the loop means no match was found
if self.mode == "any":
return (
value
if not self.keep_matches
else InvalidInstance(
if self.keep_matches:
value = InvalidInstance(
self.uuid,
"Molecule does not match any of the DescriptorsFilter descriptors.",
"Molecule does not match any of the required descriptors.",
self.name,
)
)
return (
value
if self.keep_matches
else InvalidInstance(
self.uuid,
"Molecule does not match all of the DescriptorsFilter descriptors.",
self.name,
)
)
# else: No match with forbidden descriptors was found, return original molecule
return value

if self.mode == "all":
if not self.keep_matches:
value = InvalidInstance(
self.uuid,
"Molecule matches all forbidden descriptors.",
self.name,
)
# else: All required descriptors were found, return original molecule
return value

raise ValueError(f"Invalid mode: {self.mode}")


class MixtureFilter(_MolToMolPipelineElement):
Expand Down
Loading