diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_prompt_learning_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_prompt_learning_dataset.py
index 3d0c29673c83..69cd485b0ca5 100755
--- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_prompt_learning_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_prompt_learning_dataset.py
@@ -364,6 +364,7 @@ def collate_fn(self, batch, tp_workers=0):
     def pad_batch_and_build_loss_mask(self, input_ids, batch_max, answer_starts):
         """ Pad input_ids in batch to max batch length while building loss mask """
         batch_loss_masks = []
+        padded_input_ids = []
         for ids, answer_start_idx in zip(input_ids, answer_starts):
             if answer_start_idx is not None:
                 # Loss mask where answer tokens are 1.0 and all other tokens are 0.0
@@ -375,17 +376,19 @@ def pad_batch_and_build_loss_mask(self, input_ids, batch_max, answer_starts):
             # Pad to max length
             input_length = len(ids)
             padding_length = batch_max - input_length
-            ids.extend([self.pad_token_id] * padding_length)
+            pad_extend = [self.pad_token_id] * padding_length
+            ids = ids + pad_extend
+            padded_input_ids.append(ids)

             # Account for padding in loss mask
             loss_mask.extend([0.0] * padding_length)
             batch_loss_masks.append(torch.tensor(loss_mask, dtype=torch.float))

         # Make into torch tensors
-        input_ids = torch.tensor(input_ids, dtype=torch.long)
+        padded_input_ids = torch.tensor(padded_input_ids, dtype=torch.long)
         batch_loss_masks = torch.stack(batch_loss_masks)

-        return input_ids, batch_loss_masks
+        return padded_input_ids, batch_loss_masks

     def inference_collate_fn(self, batch):
         """
diff --git a/nemo/collections/nlp/data/language_modeling/megatron/t5_prompt_learning_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/t5_prompt_learning_dataset.py
index 0f39cd8e05c9..2858d9d183df 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/t5_prompt_learning_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/t5_prompt_learning_dataset.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import enum
 import json

 import torch