Skip to content

Commit

Permalink
Minor updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Old-Shatterhand committed Jan 2, 2025
1 parent e9c4553 commit bae5b9c
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 35 deletions.
22 changes: 11 additions & 11 deletions configs/lgi/all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,25 @@ seed: 42
#origin: /home/daniel/Desktop/GIFFLAR/lgi_data_full.pkl
root_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/lgi_data
logs_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/lgi_logs
origin: /home/rjo21/Desktop/GIFFLAR/lgi_data_20.pkl
origin: /home/rjo21/Desktop/GIFFLAR/lgi_data_full.pkl
model:
glycan_encoder:
- name: gifflar
feat_dim: 128
hidden_dim: 1024
num_layers: 8
pooling: global_mean
#- name: sweetnet
# feat_dim: 128
# hidden_dim: 1024
# num_layers: 16
- name: sweetnet
feat_dim: 128
hidden_dim: 1024
num_layers: 16
lectin_encoder:
- name: ESM
layer_num: 33
- name: Ankh
layer_num: 48
- name: ProtBert
layer_num: 30
#- name: ESM
# layer_num: 33
#- name: Ankh
# layer_num: 48
#- name: ProtBert
# layer_num: 30
- name: ProstT5
layer_num: 24
batch_size: 256
Expand Down
12 changes: 6 additions & 6 deletions configs/lgi/test.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
seed: 42
root_dir: /home/daniel/Data1/roman/GIFFLAR/lgi_data
logs_dir: /home/daniel/Data1/roman/GIFFLAR/lgi_logs
origin: /home/daniel/Desktop/GIFFLAR/lgi_data_full.pkl
#root_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/lgi_data
#logs_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/lgi_logs
#origin: /home/rjo21/Desktop/GIFFLAR/lgi_data.pkl
#root_dir: /home/daniel/Data1/roman/GIFFLAR/lgi_data
#logs_dir: /home/daniel/Data1/roman/GIFFLAR/lgi_logs
#origin: /home/daniel/Desktop/GIFFLAR/lgi_data_full.pkl
root_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/lgi_data
logs_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/lgi_logs
origin: /home/rjo21/Desktop/GIFFLAR/lgi_data_20.pkl
model:
glycan_encoder:
name: gifflar
Expand Down
21 changes: 12 additions & 9 deletions experiments/lectinoracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def processed_file_names(self):


def get_ds(dl, split_idx: int):
ds = LGI_OnDiskDataset(root="/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/glycowork_data", path_idx=split_idx)
ds = LGI_OnDiskDataset(root="/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/glycowork_full", path_idx=split_idx)
data = []
for x in tqdm(dl):
data.append(Data(
Expand Down Expand Up @@ -70,22 +70,22 @@ def collate_lgi(data):
)

datamodule = LGI_GDM(
root="/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/lgi_data", filename="/home/rjo21/Desktop/GIFFLAR/lgi_data_20.pkl", hash_code="8b34af2a",
root="/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/lgi_data", filename="/home/rjo21/Desktop/GIFFLAR/lgi_data_full.pkl", hash_code="8b34af2a",
batch_size=1, transform=None, pre_transform={"GIFFLARTransform": "", "SweetNetTransform": ""},
)

#get_ds(datamodule.train_dataloader(), 0)
#get_ds(datamodule.val_dataloader(), 1)
#get_ds(datamodule.test_dataloader(), 2)
get_ds(datamodule.train_dataloader(), 0)
get_ds(datamodule.val_dataloader(), 1)
get_ds(datamodule.test_dataloader(), 2)

train_set = LGI_OnDiskDataset("/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/glycowork_data", path_idx=0)
val_set = LGI_OnDiskDataset("/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/glycowork_data", path_idx=1)
train_set = LGI_OnDiskDataset("/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/glycowork_full", path_idx=0)
val_set = LGI_OnDiskDataset("/scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/glycowork_full", path_idx=1)

model = prep_model("LectinOracle", num_classes=1)
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

m = train_model(
m, met = train_model(
model=model,
dataloaders={"train": torch.utils.data.DataLoader(train_set, batch_size=128, collate_fn=collate_lgi),
"val": torch.utils.data.DataLoader(val_set, batch_size=128, collate_fn=collate_lgi)},
Expand All @@ -100,5 +100,8 @@ def collate_lgi(data):

import pickle

with open("lectinoracle_metrics.pkl", "wb") as f:
with open("lectinoracle_full_model.pkl", "wb") as f:
pickle.dump(m, f)
with open("lectinoracle_full_metrics.pkl", "wb") as f:
pickle.dump(met, f)

10 changes: 8 additions & 2 deletions experiments/lgi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,20 @@ def query(self, aa_seq: str) -> torch.Tensor:
try:
self.data[aa_seq] = self.encoder(aa_seq)
except Exception as e:
# print(e)
print(e)
self.data[aa_seq] = None

return self.data[aa_seq]

def batch_query(self, aa_seqs) -> torch.Tensor:
# print([self.query(aa_seq) for aa_seq in aa_seqs])
return torch.stack([self.query(aa_seq) for aa_seq in aa_seqs])
results = [self.query(aa_seq) for aa_seq in aa_seqs]
dummy = None
for x in results:
if x is not None:
dummy = torch.zeros_like(x)
break
return torch.stack([dummy if res is None else res for res in results])


class LGI_Model(LightningModule):
Expand Down
6 changes: 3 additions & 3 deletions experiments/train_lgi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

from argparse import ArgumentParser
import time
Expand Down Expand Up @@ -51,7 +51,7 @@ def train(**kwargs):

datamodule = LGI_GDM(
root=kwargs["root_dir"], filename=kwargs["origin"], hash_code=kwargs["hash"],
batch_size=kwargs["model"].get("batch_size", 1), transform=None,
batch_size=kwargs["model"].get("batch_size", 1), transform=None, num_workers=12,
pre_transform=get_pretransforms("", **(kwargs["pre-transforms"] or {})),
)

Expand All @@ -74,7 +74,7 @@ def train(**kwargs):

trainer = Trainer(
callbacks=[
ModelCheckpoint(dirpath=Path(kwargs["logs_dir"]) / "full", monitor="val/loss"),
ModelCheckpoint(dirpath=Path(kwargs["logs_dir"]) / f"LGI_{glycan_model_name}{lectin_model_name}" / "weights", monitor="val/loss"),
RichProgressBar(),
RichModelSummary(),
],
Expand Down
10 changes: 6 additions & 4 deletions gifflar/data/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class GlycanDataModule(LightningDataModule):
"""DataModule holding datasets for Glycan-specific training"""

def __init__(self, batch_size: int = 128, num_workers: int = 1, **kwargs: Any):
def __init__(self, batch_size: int = 128, num_workers: int = 0, **kwargs: Any):
"""
Initialize the DataModule with a given batch size.
Expand Down Expand Up @@ -78,6 +78,7 @@ def __init__(
train_frac: float = 0.8,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
num_workers: int = 0,
**kwargs: Any,
):
"""
Expand All @@ -94,7 +95,7 @@ def __init__(
**kwargs: Additional arguments to pass to the Pretrain
"""
root = Path(file_path).parent
super().__init__(batch_size)
super().__init__(batch_size, num_workers=num_workers)
ds = PretrainGDs(root=root, filename=file_path, hash_code=hash_code, transform=transform,
pre_transform=pre_transform, **kwargs)
self.train, self.val = torch.utils.data.dataset.random_split(ds, [train_frac, 1 - train_frac])
Expand All @@ -114,6 +115,7 @@ def __init__(
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
force_reload: bool = False,
num_workers: int = 0,
**dataset_args: dict[str, Any],
):
"""
Expand All @@ -128,7 +130,7 @@ def __init__(
pre_transform: The pre-transform to apply to the data
**dataset_args: Additional arguments to pass to the DownstreamGDs
"""
super().__init__(batch_size)
super().__init__(batch_size, num_workers=num_workers)
self.train = self.ds_class(
root=root, filename=filename, split="train", hash_code=hash_code, transform=transform,
pre_transform=pre_transform, force_reload=force_reload, **dataset_args,
Expand All @@ -143,4 +145,4 @@ def __init__(
)

class LGI_GDM(DownsteamGDM):
ds_class = LGIDataset
ds_class = LGIDataset

0 comments on commit bae5b9c

Please sign in to comment.