diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py index ad5ec0634c..09526572ef 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/datamodule.py @@ -79,7 +79,7 @@ def __getitem__(self, index: int) -> BertSample: sequence = self.sequences[index] tokenized_sequence = self._tokenize(sequence) - label = tokenized_sequence if len(self.labels) == 0 else self.labels[index] + label = tokenized_sequence if len(self.labels) == 0 else torch.Tensor([self.labels[index]]) # Overall mask for a token being masked in some capacity - either mask token, random token, or left as-is loss_mask = ~torch.isin(tokenized_sequence, Tensor(self.tokenizer.all_special_ids)) @@ -108,7 +108,7 @@ def load_data(self, csv_path: str | os.PathLike) -> Tuple[Sequence, Sequence]: df = pd.read_csv(csv_path) sequences = df["sequences"].tolist() - if "label" in df.columns: + if "labels" in df.columns: labels = df["labels"].tolist() else: labels = []