Skip to content

Commit

Permalink
T5 prompt learning fixes missing from r.11.0 merge (NVIDIA#5075) (NVIDIA#5101)
Browse files Browse the repository at this point in the history

* Fix special tokens

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Fix

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Empty

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
Co-authored-by: David <amosalla@asu.edu>

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
Co-authored-by: Sandeep Subramanian <sandeep.subramanian.1@umontreal.ca>
Co-authored-by: David <amosalla@asu.edu>
Co-authored-by: Eric Harper <complex451@gmail.com>
Signed-off-by: 1-800-bad-code <shane.carroll@utsa.edu>
  • Loading branch information
4 people authored and 1-800-BAD-CODE committed Nov 13, 2022
1 parent fbdce6b commit 9af1daa
Showing 1 changed file with 7 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -511,11 +511,15 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A
idx = pred.index(self.tokenizer.eos_id)
pred = pred[:idx]

special_token_ids = (
self.tokenizer.special_token_to_id.values()
if hasattr(self.tokenizer, 'special_token_to_id')
else self.tokenizer.tokenizer.additional_special_tokens_ids
)
pred = [
id
for id in pred
if id not in self.tokenizer.tokenizer.additional_special_tokens_ids
and id not in self.tokenizer.text_to_ids(T5Sentinel.FIRST.value)
if id not in special_token_ids and id not in self.tokenizer.text_to_ids(T5Sentinel.FIRST.value)
] # delete the sentinel token at the beginning of prediction

pred = self.tokenizer.ids_to_text(pred)
Expand All @@ -532,8 +536,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A
label = [
id
for id in label
if id not in self.tokenizer.tokenizer.additional_special_tokens_ids
and id not in self.tokenizer.text_to_ids(T5Sentinel.FIRST.value)
if id not in special_token_ids and id not in self.tokenizer.text_to_ids(T5Sentinel.FIRST.value)
] # delete the sentinel token at the beginning of label

label = self.tokenizer.ids_to_text(label)
Expand Down

0 comments on commit 9af1daa

Please sign in to comment.