From c04a60620834b3e8b96a7707d57a5ade5eeac948 Mon Sep 17 00:00:00 2001
From: Holger Roth
Date: Tue, 17 Dec 2024 21:04:18 -0500
Subject: [PATCH 1/2] fix csv dataset

Signed-off-by: Holger Roth
---
 .../bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py
index ad5ec0634c..19407749a6 100644
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py
@@ -108,7 +108,7 @@ def load_data(self, csv_path: str | os.PathLike) -> Tuple[Sequence, Sequence]:
         df = pd.read_csv(csv_path)
         sequences = df["sequences"].tolist()
 
-        if "label" in df.columns:
+        if "labels" in df.columns:
             labels = df["labels"].tolist()
         else:
             labels = []

From d4672b8b5e7357344559d6e168f4a87664a4baff Mon Sep 17 00:00:00 2001
From: Holger Roth
Date: Tue, 17 Dec 2024 21:28:41 -0500
Subject: [PATCH 2/2] increase dimensions of labels

Signed-off-by: Holger Roth
---
 .../bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py
index 19407749a6..09526572ef 100644
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py
@@ -79,7 +79,7 @@ def __getitem__(self, index: int) -> BertSample:
         sequence = self.sequences[index]
         tokenized_sequence = self._tokenize(sequence)
 
-        label = tokenized_sequence if len(self.labels) == 0 else self.labels[index]
+        label = tokenized_sequence if len(self.labels) == 0 else torch.Tensor([self.labels[index]])
         # Overall mask for a token being masked in some capacity - either mask token, random token, or left as-is
         loss_mask = ~torch.isin(tokenized_sequence, Tensor(self.tokenizer.all_special_ids))
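
The first patch makes the column check consistent with the column that is actually read: load_data previously tested for a "label" column but then indexed df["labels"], so a CSV with a "labels" column silently lost its labels, while a CSV with only a "label" column raised a KeyError. Below is a minimal sketch of a CSV the patched code accepts; the file name and example rows are illustrative, only the "sequences"/"labels" column names come from the patched code.

# Sketch: write a CSV in the layout the patched load_data expects.
import pandas as pd

pd.DataFrame(
    {
        "sequences": ["MKTAYIAKQR", "MSLLTEVETY"],  # illustrative protein sequences
        "labels": [0.7, 0.3],                       # optional per-sequence labels
    }
).to_csv("finetune_data.csv", index=False)

# Mirrors the patched logic: guard and access now use the same key, so labels
# are picked up when the column is present and default to [] when it is absent.
df = pd.read_csv("finetune_data.csv")
sequences = df["sequences"].tolist()
labels = df["labels"].tolist() if "labels" in df.columns else []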
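
The second patch returns a per-sequence label as a 1-D tensor rather than a bare Python scalar, giving each sample's label an explicit dimension, which downstream collation presumably expects. A minimal sketch of the resulting shape and dtype; the value 0.7 is illustrative.

# Sketch: what torch.Tensor([scalar]) produces for a label read from the
# "labels" column of the CSV.
import torch

raw_label = 0.7                     # illustrative scalar from the CSV
label = torch.Tensor([raw_label])   # float32 tensor of shape (1,)

assert label.shape == (1,)
assert label.dtype == torch.float32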