Skip to content

Commit

Permalink
update and fix a bug in RNA Ligand and add the possibility to do aggr…
Browse files Browse the repository at this point in the history
…egation in NodeAttributeFilter
  • Loading branch information
wisskarrou committed Jan 8, 2025
1 parent f3b3363 commit afc14a3
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/rnaglib/tasks/RNA_Ligand/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
data = pd.read_csv(os.path.join(os.path.dirname(__file__), "data/gmsm_dataset.csv"))

# Creating task
ta = LigandIdentification('RNA-Ligand', data, recompute=True, filter_by_size=True, filter_by_resolution=True)
ta = LigandIdentification('RNA-Ligand', data, recompute=True, filter_by_size=True, filter_by_resolution=True, in_memory=False)

# Splitting dataset
print("Splitting Dataset")
Expand Down
27 changes: 15 additions & 12 deletions src/rnaglib/tasks/RNA_Ligand/ligand_identity.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os

import pandas as pd
import numpy as np

from rnaglib.tasks import RNAClassificationTask
from rnaglib.data_loading import RNADataset
from rnaglib.encoders import IntEncoder
from rnaglib.transforms import FeaturesComputer, AnnotatorFromDict, PartitionFromDict, ResidueNameFilter, RBPTransform, ComposeFilters, ResidueAttributeFilter
from rnaglib.transforms import FeaturesComputer, AnnotatorFromDict, PartitionFromDict, ResidueNameFilter, RBPTransform, ComposeFilters, ResidueAttributeFilter, RNAAttributeFilter
from rnaglib.utils import dump_json


Expand Down Expand Up @@ -48,24 +49,26 @@ def process(self):
# Instantiate transforms to apply
nt_partition = PartitionFromDict(partition_dict=self.bp_dict)
annotator = AnnotatorFromDict(annotation_dict=self.ligands_dict, name="ligand_code")
#protein_content_annotator = RBPTransform(structures_dir=dataset.structures_path, protein_number_annotations=False, distances=[4.,6.,8.])
protein_content_annotator = RBPTransform(structures_dir=dataset.structures_path, protein_number_annotations=False, distances=[4.,6.,8.])

# Run through database, applying our filters
all_binding_pockets = []
os.makedirs(self.dataset_path, exist_ok=True)
for rna in dataset:
if filters.forward(rna):
for binding_pocket_dict in nt_partition(rna):
annotated_binding_pocket = annotator(binding_pocket_dict)["rna"]
#annotated_binding_pocket = protein_content_annotator(annotated_binding_pocket_dict)["rna"]
if self.in_memory:
all_binding_pockets.append(annotated_binding_pocket)
else:
all_binding_pockets.append(annotated_binding_pocket.name)
dump_json(
os.path.join(self.dataset_path, f"{annotated_binding_pocket.name}.json"),
annotated_binding_pocket,
)
annotated_binding_pocket_dict = annotator(binding_pocket_dict)
annotated_binding_pocket = protein_content_annotator(annotated_binding_pocket_dict)
protein_content_filter = ResidueAttributeFilter(attribute="protein_content_8.0", aggregation_mode="aggfunc", value_checker=lambda x: x<10, aggfunc = np.mean)
if protein_content_filter.forward(annotated_binding_pocket):
if self.in_memory:
all_binding_pockets.append(annotated_binding_pocket["rna"])
else:
all_binding_pockets.append(annotated_binding_pocket["rna"].name)
dump_json(
os.path.join(self.dataset_path, f"""{annotated_binding_pocket["rna"].name}.json"""),
annotated_binding_pocket["rna"],
)
if self.in_memory:
dataset = RNADataset(rnas=all_binding_pockets)
else:
Expand Down
24 changes: 20 additions & 4 deletions src/rnaglib/transforms/filter/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,13 @@ class ResidueAttributeFilter(FilterTransform):
"""Reject RNAs that lack a certain annotation at the whole residue-level.
:param attribute: which node-level attribute to look for.
:param aggregation_mode: str (either "aggfunc" or "min_valid"); if set to "aggfunc", keeps an RNA if the output of
the aggregation function of the residue attribute at the RNA level passes the value_checker; if set to "min_valid",
keeps an RNA if more than min_valid nodes pass the value_checker
:param value_checker: function with accepts the value of the desired attribute and returns True/False
:param min_valid: minium number of valid nodes that pass the filter for keeping the RNA.
:param aggfunc: function to aggregate the residue labels at the RNA level (only if aggregarion_mode is "aggfunc")
:param min_valid: minium number of valid nodes that pass the filter for keeping the RNA. (only if aggregation_mode
is "min_valid")
Example
Expand All @@ -88,30 +93,41 @@ class ResidueAttributeFilter(FilterTransform):
def __init__(
self,
attribute: str,
aggregation_mode: str = "min_valid",
value_checker: Callable = None,
min_valid: int = 1,
aggfunc: Callable = None,
**kwargs,
):
self.attribute = attribute
self.aggregation_mode = aggregation_mode
self.min_valid = min_valid
self.aggfunc = aggfunc
self.value_checker = value_checker
super().__init__(**kwargs)
pass

def forward(self, data: dict):
n_valid = 0
g = data["rna"]
if self.aggregation_mode=="aggfunc":
vals_list = []
for node, ndata in g.nodes(data=True):
try:
val = ndata[self.attribute]
except KeyError:
continue
else:
if self.value_checker(val):
if self.aggregation_mode=="min_valid" and self.value_checker(val):
n_valid += 1
if n_valid >= self.min_valid:
elif self.aggregation_mode=="aggfunc":
vals_list.append(val)
if self.aggregation_mode=="min_valid" and n_valid >= self.min_valid:
return True
return False
if self.aggregation_mode=="min_valid":
return False
else:
return self.aggfunc(vals_list)

class ResidueNameFilter(FilterTransform):
def __init__(
Expand Down

0 comments on commit afc14a3

Please sign in to comment.