Skip to content

Commit

Permalink
Merge pull request NVIDIA#6 from yzhang123/truncate_right
Browse files Browse the repository at this point in the history
truncate data from right for megatron finetuning
  • Loading branch information
sam1373 authored May 1, 2023
2 parents a373b37 + c8486c1 commit df7c214
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
12 changes: 10 additions & 2 deletions nemo/collections/nlp/data/common/sequence_to_sequence_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
add_bos_to_input: bool = True,
add_eos_to_input: bool = True,
replace_bos_with_pad: bool = False,
truncate_from_right: bool = True, # scrolls dataset is truncated from right
):
super().__init__()
self.src_file_name = src_file_name
Expand All @@ -54,6 +55,7 @@ def __init__(
self.add_bos_to_input = add_bos_to_input
self.add_eos_to_input = add_eos_to_input
self.replace_bos_with_pad = replace_bos_with_pad
self.truncate_from_right = truncate_from_right
assert self.max_src_seq_length > 0
assert self.max_tgt_seq_length > 0
self._check_files_exist()
Expand Down Expand Up @@ -94,9 +96,15 @@ def _get_examples(self):
)
# Truncate to max sequence length.
if len(src) > self.max_src_seq_length:
src = src[-self.max_src_seq_length + 1 :]
if self.truncate_from_right:
src = src[: self.max_src_seq_length]
else:
src = src[-self.max_src_seq_length + 1 :]
if len(tgt) > self.max_tgt_seq_length:
tgt = tgt[-self.max_tgt_seq_length + 1 :]
if self.truncate_from_right:
tgt = tgt[: self.max_tgt_seq_length]
else:
tgt = tgt[-self.max_tgt_seq_length + 1 :]
self.examples.append({'src': src, 'tgt': tgt})

logging.info(f'Dataset Length : {len(self.examples)}')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,7 @@ def _build_train_dataset(self, data_cfg):
add_bos_to_input=data_cfg.get('add_bos_to_input', True),
add_eos_to_input=data_cfg.get('add_eos_to_input', True),
replace_bos_with_pad=data_cfg.get('replace_bos_with_pad', False),
truncate_from_right=data_cfg.get('truncate_from_right', True),
)
datasets.append(dataset)

Expand Down

0 comments on commit df7c214

Please sign in to comment.