Skip to content

Commit

Permalink
Fix csv dataset (#543)
Browse files Browse the repository at this point in the history
Fix typo in label key.

## Summary
Check for loading labels should use "labels" instead of "label". 

## Details
Fixed the typo

## Usage
n/a

## Testing
no new tests

---------

Signed-off-by: Holger Roth <hroth@nvidia.com>
  • Loading branch information
holgerroth authored Dec 18, 2024
1 parent a8934a0 commit 4d44f5d
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __getitem__(self, index: int) -> BertSample:
sequence = self.sequences[index]
tokenized_sequence = self._tokenize(sequence)

label = tokenized_sequence if len(self.labels) == 0 else self.labels[index]
label = tokenized_sequence if len(self.labels) == 0 else torch.Tensor([self.labels[index]])
# Overall mask for a token being masked in some capacity - either mask token, random token, or left as-is
loss_mask = ~torch.isin(tokenized_sequence, Tensor(self.tokenizer.all_special_ids))

Expand Down Expand Up @@ -108,7 +108,7 @@ def load_data(self, csv_path: str | os.PathLike) -> Tuple[Sequence, Sequence]:
df = pd.read_csv(csv_path)
sequences = df["sequences"].tolist()

if "label" in df.columns:
if "labels" in df.columns:
labels = df["labels"].tolist()
else:
labels = []
Expand Down

0 comments on commit 4d44f5d

Please sign in to comment.