From 9af1daa698ce1e9e494820103c86df16bbe17c9a Mon Sep 17 00:00:00 2001
From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com>
Date: Fri, 7 Oct 2022 10:00:17 -0600
Subject: [PATCH] T5 prompt learning fixes missing from r.11.0 merge (#5075)
 (#5101)

* Fix special tokens

Signed-off-by: MaximumEntropy

* Fix

Signed-off-by: MaximumEntropy

* Empty

Signed-off-by: MaximumEntropy

Signed-off-by: MaximumEntropy
Co-authored-by: David

Signed-off-by: MaximumEntropy
Co-authored-by: Sandeep Subramanian
Co-authored-by: David
Co-authored-by: Eric Harper
Signed-off-by: 1-800-bad-code
---
 .../megatron_t5_prompt_learning_model.py      | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py
index 512faaafe6cd..57d5e70a405d 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py
@@ -511,11 +511,15 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A
                 idx = pred.index(self.tokenizer.eos_id)
                 pred = pred[:idx]
 
+            special_token_ids = (
+                self.tokenizer.special_token_to_id.values()
+                if hasattr(self.tokenizer, 'special_token_to_id')
+                else self.tokenizer.tokenizer.additional_special_tokens_ids
+            )
             pred = [
                 id
                 for id in pred
-                if id not in self.tokenizer.tokenizer.additional_special_tokens_ids
-                and id not in self.tokenizer.text_to_ids(T5Sentinel.FIRST.value)
+                if id not in special_token_ids and id not in self.tokenizer.text_to_ids(T5Sentinel.FIRST.value)
             ]  # delete the sentinel token at the beginning of prediction
 
             pred = self.tokenizer.ids_to_text(pred)
@@ -532,8 +536,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A
             label = [
                 id
                 for id in label
-                if id not in self.tokenizer.tokenizer.additional_special_tokens_ids
-                and id not in self.tokenizer.text_to_ids(T5Sentinel.FIRST.value)
+                if id not in special_token_ids and id not in self.tokenizer.text_to_ids(T5Sentinel.FIRST.value)
             ]  # delete the sentinel token at the beginning of label
 
             label = self.tokenizer.ids_to_text(label)
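
The core of this patch is the hasattr fallback: the old code assumed a HuggingFace-wrapped tokenizer exposing tokenizer.additional_special_tokens_ids, which breaks for tokenizers that instead expose a special_token_to_id mapping. Below is a minimal standalone sketch of that fallback; the two stub classes and the special_ids helper are hypothetical stand-ins, modeling only the attributes the patch touches, not NeMo's real tokenizer classes.

# Sketch of the fallback introduced by this patch, with hypothetical stubs.
class SentencePieceStyleTokenizer:
    """Stand-in for a tokenizer exposing a name -> id mapping of special tokens."""
    special_token_to_id = {'<extra_id_0>': 32000, '<extra_id_1>': 32001}

class HFInner:
    """Stand-in for an inner HuggingFace tokenizer with a special-token id list."""
    additional_special_tokens_ids = [32000, 32001]

class HFStyleTokenizer:
    """Stand-in for a wrapper holding the HuggingFace tokenizer as `.tokenizer`."""
    tokenizer = HFInner()

def special_ids(tok):
    # Prefer the special_token_to_id mapping when present; otherwise fall
    # back to the wrapped HuggingFace tokenizer's special-token id list.
    return (
        tok.special_token_to_id.values()
        if hasattr(tok, 'special_token_to_id')
        else tok.tokenizer.additional_special_tokens_ids
    )

pred = [32000, 101, 102, 32001, 103]
for tok in (SentencePieceStyleTokenizer(), HFStyleTokenizer()):
    # Same filtering as the patched predict_step, minus the sentinel lookup.
    print([i for i in pred if i not in special_ids(tok)])  # [101, 102, 103]

Computing special_token_ids once before the list comprehensions (rather than inside them, as before) also avoids re-resolving the attribute per token, and the same variable is reused for both the pred and label filters in the second hunk.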