From 4d44f5d92f80f937d0f603ff9203346227812b19 Mon Sep 17 00:00:00 2001 From: Holger Roth <6304754+holgerroth@users.noreply.github.com> Date: Wed, 18 Dec 2024 14:02:28 -0500 Subject: [PATCH] Fix csv dataset (#543) Fix typo in label key. ## Summary Check for loading labels should use "labels" instead of "label". ## Details Fixed the typo in the `load_data` column check, and wrapped the per-sample label in a `torch.Tensor` in `__getitem__`. ## Usage n/a ## Testing no new tests --------- Signed-off-by: Holger Roth --- .../src/bionemo/esm2/model/finetune/datamodule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py index ad5ec0634c..09526572ef 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py @@ -79,7 +79,7 @@ def __getitem__(self, index: int) -> BertSample: sequence = self.sequences[index] tokenized_sequence = self._tokenize(sequence) - label = tokenized_sequence if len(self.labels) == 0 else self.labels[index] + label = tokenized_sequence if len(self.labels) == 0 else torch.Tensor([self.labels[index]]) # Overall mask for a token being masked in some capacity - either mask token, random token, or left as-is loss_mask = ~torch.isin(tokenized_sequence, Tensor(self.tokenizer.all_special_ids)) @@ -108,7 +108,7 @@ def load_data(self, csv_path: str | os.PathLike) -> Tuple[Sequence, Sequence]: df = pd.read_csv(csv_path) sequences = df["sequences"].tolist() - if "label" in df.columns: + if "labels" in df.columns: labels = df["labels"].tolist() else: labels = []