From c8486c1f2d8df333c53b99728812ebf83a2597a6 Mon Sep 17 00:00:00 2001 From: Yang Zhang Date: Mon, 1 May 2023 13:09:33 -0700 Subject: [PATCH] truncate data from right for megatron finetuning Signed-off-by: Yang Zhang --- .../nlp/data/common/sequence_to_sequence_dataset.py | 12 ++++++++++-- .../language_modeling/megatron_finetune_model.py | 1 + 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/nemo/collections/nlp/data/common/sequence_to_sequence_dataset.py b/nemo/collections/nlp/data/common/sequence_to_sequence_dataset.py index d4ca85289a6f..1f6b8104f473 100644 --- a/nemo/collections/nlp/data/common/sequence_to_sequence_dataset.py +++ b/nemo/collections/nlp/data/common/sequence_to_sequence_dataset.py @@ -43,6 +43,7 @@ def __init__( add_bos_to_input: bool = True, add_eos_to_input: bool = True, replace_bos_with_pad: bool = False, + truncate_from_right: bool = True, # scrolls dataset is truncated from right ): super().__init__() self.src_file_name = src_file_name @@ -54,6 +55,7 @@ def __init__( self.add_bos_to_input = add_bos_to_input self.add_eos_to_input = add_eos_to_input self.replace_bos_with_pad = replace_bos_with_pad + self.truncate_from_right = truncate_from_right assert self.max_src_seq_length > 0 assert self.max_tgt_seq_length > 0 self._check_files_exist() @@ -94,9 +96,15 @@ def _get_examples(self): ) # Truncate to max sequence length. 
if len(src) > self.max_src_seq_length: - src = src[-self.max_src_seq_length + 1 :] + if self.truncate_from_right: + src = src[: self.max_src_seq_length] + else: + src = src[-self.max_src_seq_length + 1 :] if len(tgt) > self.max_tgt_seq_length: - tgt = tgt[-self.max_tgt_seq_length + 1 :] + if self.truncate_from_right: + tgt = tgt[: self.max_tgt_seq_length] + else: + tgt = tgt[-self.max_tgt_seq_length + 1 :] self.examples.append({'src': src, 'tgt': tgt}) logging.info(f'Dataset Length : {len(self.examples)}') diff --git a/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py b/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py index 347f6e665e6f..d83501e3dd65 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py @@ -601,6 +601,7 @@ def _build_train_dataset(self, data_cfg): add_bos_to_input=data_cfg.get('add_bos_to_input', True), add_eos_to_input=data_cfg.get('add_eos_to_input', True), replace_bos_with_pad=data_cfg.get('replace_bos_with_pad', False), + truncate_from_right=data_cfg.get('truncate_from_right', True), ) datasets.append(dataset)