Skip to content

Commit

Permalink
Minor bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Old-Shatterhand committed Feb 11, 2025
1 parent 2bc74d6 commit e40cd54
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 19 deletions.
8 changes: 4 additions & 4 deletions configs/downstream/test_lm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ model:
learning_rate: 0.001
batch_size: 256
optimizer: Adam
suffix: _glylm_bpe_glyles_25_t6_20
suffix: _bpe_glyles_25_t6_20
- name: glylm
token_file: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/unique/GlyLM/bpe_glyles_5000.pkl
model_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/unique/GlyLM/bpe_glyles_5000_t6/checkpoint-185120
Expand All @@ -28,7 +28,7 @@ model:
learning_rate: 0.001
batch_size: 256
optimizer: Adam
suffix: _glylm_bpe_glyles_50_t6_20
suffix: _bpe_glyles_50_t6_20
- name: glylm
token_file: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/unique/GlyLM/bpe_glyles_7500.pkl
model_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/unique/GlyLM/bpe_glyles_7500_t6/checkpoint-185120
Expand All @@ -37,7 +37,7 @@ model:
learning_rate: 0.001
batch_size: 256
optimizer: Adam
suffix: _glylm_bpe_glyles_75_t6_20
suffix: _bpe_glyles_75_t6_20
- name: glylm
token_file: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/unique/GlyLM/bpe_glyles_10000.pkl
model_dir: /scratch/SCRATCH_SAS/roman/Gothenburg/GIFFLAR/unique/GlyLM/bpe_glyles_10000_t6/checkpoint-185120
Expand All @@ -46,4 +46,4 @@ model:
learning_rate: 0.001
batch_size: 256
optimizer: Adam
suffix: _glylm_bpe_glyles_100_t6_20
suffix: _bpe_glyles_100_t6_20
18 changes: 10 additions & 8 deletions gifflar/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,13 @@ def get_immunogenicity(root: Path | str) -> Path:
Returns:
The filepath of the processed immunogenicity data.
"""
if not (p := (root / Path("immunogenicity.tsv"))).exists():
root = Path(root)
if not (p := (root / "immunogenicity.tsv")).exists():
# Download the data
urllib.request.urlretrieve("https://torchglycan.s3.us-east-2.amazonaws.com/downstream/glycan_immunogenicity.csv", p)
urllib.request.urlretrieve("https://torchglycan.s3.us-east-2.amazonaws.com/downstream/glycan_immunogenicity.csv", p.with_suffix(".csv"))

# Process the data and remove unnecessary columns
df = pd.read_csv("immunogenicity.csv")[["glycan", "immunogenicity"]]
df = pd.read_csv(p.with_suffix(".csv"))[["glycan", "immunogenicity"]]
df.rename(columns={"glycan": "IUPAC"}, inplace=True)
df.dropna(inplace=True)

Expand All @@ -118,7 +119,7 @@ def get_immunogenicity(root: Path | str) -> Path:

df.drop("immunogenicity", axis=1, inplace=True)
df.to_csv(p, sep="\t", index=False)
with open("immunogenicity_classes.tsv", "w") as f:
with open(root / "immunogenicity_classes.tsv", "w") as f:
for n, i in classes.items():
print(n, i, sep="\t", file=f)
return p
Expand All @@ -134,9 +135,10 @@ def get_glycosylation(root: Path | str) -> Path:
Returns:
The filepath of the processed glycosylation data.
"""
if not (p := root / Path("glycosylation.tsv")).exists():
urllib.request.urlretrieve("https://torchglycan.s3.us-east-2.amazonaws.com/downstream/glycan_properties.csv", p)
df = pd.read_csv("glycosylation.csv")[["glycan", "link"]]
root = Path(root)
if not (p := root / "glycosylation.tsv").exists():
urllib.request.urlretrieve("https://torchglycan.s3.us-east-2.amazonaws.com/downstream/glycan_properties.csv", p.with_suffix(".csv"))
df = pd.read_csv(p.with_suffix(".csv"))[["glycan", "link"]]
df.rename(columns={"glycan": "IUPAC"}, inplace=True)
df.dropna(inplace=True)

Expand All @@ -146,7 +148,7 @@ def get_glycosylation(root: Path | str) -> Path:

df.drop("link", axis=1, inplace=True)
df.to_csv(p, sep="\t", index=False)
with open("glycosylation_classes.tsv", "w") as f:
with open(root / "glycosylation_classes.tsv", "w") as f:
for n, i in classes.items():
print(n, i, sep="\t", file=f)
return p
Expand Down
14 changes: 8 additions & 6 deletions gifflar/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from gifflar.data.utils import GlycanStorage


PROCESS_CHUNK_SIZE = 100_000

class GlycanOnDiskDataset(OnDiskDataset):
def __init__(
self,
Expand Down Expand Up @@ -190,7 +192,7 @@ def process(self) -> None:
gs = GlycanStorage(Path(self.root).parent)
with open(self.filename, "r") as glycans:
for i, line in enumerate(glycans.readlines()):
if i % 1000 == 0:
if i % PROCESS_CHUNK_SIZE == 0:
self.process_(data, final=False)
del data
data = []
Expand Down Expand Up @@ -278,7 +280,7 @@ def process(self) -> None:
gs = GlycanStorage(Path(self.root).parent)
data = []
for i, (_, row) in tqdm(enumerate(df.iterrows())):
if i % 1000 == 0:
if i % PROCESS_CHUNK_SIZE == 0:
self.process_(data, path_idx=self.splits[self.split], final=False)
del data
data = []
Expand Down Expand Up @@ -342,7 +344,7 @@ def process_pkl(self) -> None:
gs = GlycanStorage(Path(self.root).parent)
data = []
for i, (lectin_id, glycan_id, value, split) in tqdm(enumerate(inter)):
if i % 1000 == 0:
if i % PROCESS_CHUNK_SIZE == 0:
self.process_(data, path_idx=self.splits[split], final=False)
del data
data = []
Expand All @@ -364,7 +366,7 @@ def process_csv(self, sep: str) -> None:
gs = GlycanStorage(Path(self.root).parent)
data = []
for i, (_, row) in tqdm(enumerate(inter.iterrows())):
if i % 1000 == 0:
if i % PROCESS_CHUNK_SIZE == 0:
self.process_(data, path_idx=self.splits[split], final=False)
del data
data = []
Expand Down Expand Up @@ -400,7 +402,7 @@ def process_pkl(self) -> None:
gs = GlycanStorage(Path(self.root).parent)
data = []
for i, (lectin, glycan, glycan_val, decoy, decoy_val, split) in tqdm(enumerate(lgis)):
if i % 1000 == 0:
if i % PROCESS_CHUNK_SIZE == 0:
self.process_(data, path_idx=self.splits[split], final=False)
del data
data = []
Expand Down Expand Up @@ -431,7 +433,7 @@ def process_csv(self, sep):
gs = GlycanStorage(Path(self.root).parent)
data = []
for i, (_, row) in tqdm(enumerate(inter.iterrows())):
if i % 1000 == 0:
if i % PROCESS_CHUNK_SIZE == 0:
self.process_(data, path_idx=self.splits[split], final=False)
del data
data = []
Expand Down
2 changes: 1 addition & 1 deletion gifflar/model/glylm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, token_file, model_dir, hidden_dim: int, *args, **kwargs):
tokenizer.load(token_file)
glycan_lm = EsmModel.from_pretrained(model_dir)
self.encoder = lambda x: pipeline(tokenizer, glycan_lm, x)
self.head, self.loss, self.metrics = get_prediction_head(hidden_dim, 1, "regression")
self.head, self.loss, self.metrics = get_prediction_head(hidden_dim, self.output_dim, self.task)

def forward(self, batch: HeteroDataBatch) -> dict[str, torch.Tensor]:
"""
Expand Down

0 comments on commit e40cd54

Please sign in to comment.