From 7ec5686dd39c69ac2efeb6276e5bb488bb22f070 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com>
Date: Tue, 25 Apr 2023 09:44:21 -0400
Subject: [PATCH] 1. Added external index sample. (#6462) (#6483)

Signed-off-by: Micha Livne

Co-authored-by: Micha Livne
---
 .../megatron/dataset_utils.py | 25 +++++++++++--------
 1 file changed, 14 insertions(+), 11 deletions(-)

diff --git a/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py b/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py
index d1f0718a6abd..775ac271d5b2 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py
@@ -34,6 +34,7 @@
 import os
 import subprocess
 import time
+from typing import Any
 
 import numpy as np
 import torch
@@ -1255,6 +1256,7 @@ def get_samples_mapping(
     name,
     binary_head,
     index_mapping_dir: str = None,
+    samples_mapping: Any = None,
 ):
     """Get a list that maps a sample index to a starting sentence index, end sentence index, and length"""
 
@@ -1280,8 +1282,8 @@ def get_samples_mapping(
     indexmap_filename += '_{}s'.format(seed)
     indexmap_filename += '.npy'
 
-    # Build the indexed mapping if not exist.
-    if torch.distributed.get_rank() == 0 and not os.path.isfile(indexmap_filename):
+    # Build the indexed mapping if not exist and not provided externally.
+    if samples_mapping is None and torch.distributed.get_rank() == 0 and not os.path.isfile(indexmap_filename):
         # Fake index mapping if missing
         if (getattr(indexed_dataset, 'doc_idx', None) is None) and (getattr(indexed_dataset, 'sizes', None) is None):
             make_indexed_dataset_compatibility(indexed_dataset)
@@ -1334,15 +1336,16 @@ def get_samples_mapping(
         torch.distributed.get_world_size()
         // torch.distributed.get_world_size(group=parallel_state.get_tensor_model_parallel_group())
     )
-    # Load indexed dataset.
-    logging.info(' > loading indexed mapping from {}'.format(indexmap_filename))
-    start_time = time.time()
-    samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
-    logging.info('    loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time))
-    logging.info('    total number of samples: {}'.format(samples_mapping.shape[0]))
+    # Load indexed dataset if not given externally.
+    if samples_mapping is None:
+        logging.info(' > loading indexed mapping from {}'.format(indexmap_filename))
+        start_time = time.time()
+        samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
+        logging.info('    loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time))
+        logging.info('    total number of samples: {}'.format(samples_mapping.shape[0]))
 
-    # Deallocate temporary numpy arrays that were created for `get_samples_mapping()` when needed
-    if hasattr(indexed_dataset, 'doc_idx') and hasattr(indexed_dataset, 'sizes'):
-        deallocate_indexed_dataset_memory(indexed_dataset)
+        # Deallocate temporary numpy arrays that were created for `get_samples_mapping()` when needed
+        if hasattr(indexed_dataset, 'doc_idx') and hasattr(indexed_dataset, 'sizes'):
+            deallocate_indexed_dataset_memory(indexed_dataset)
 
     return samples_mapping
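
For context on how the new argument is meant to be used: when samples_mapping is passed in, get_samples_mapping() neither builds nor reads the cached .npy index and simply returns the supplied mapping. The sketch below shows one way a caller might hand in a precomputed index. It is illustrative only; apart from name, binary_head, index_mapping_dir, and the new samples_mapping parameter, which are visible in the hunks above, the remaining keyword arguments, the helper function name, and the mapping path are assumptions for illustration, not part of this patch.

# Hypothetical helper, not part of the patch. Assumes an already-constructed
# indexed dataset and that the remaining keyword arguments mirror the
# existing get_samples_mapping() signature.
import numpy as np

from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import get_samples_mapping


def get_samples_mapping_from_external_index(
    indexed_dataset, data_prefix, num_epochs, max_num_samples, max_seq_length, short_seq_prob, seed, mapping_path
):
    # Per the docstring, each row of the mapping gives
    # (start sentence index, end sentence index, length) for one sample.
    external_mapping = np.load(mapping_path, allow_pickle=True, mmap_mode='r')

    return get_samples_mapping(
        indexed_dataset=indexed_dataset,
        data_prefix=data_prefix,
        num_epochs=num_epochs,
        max_num_samples=max_num_samples,
        max_seq_length=max_seq_length,
        short_seq_prob=short_seq_prob,
        seed=seed,
        name='train',
        binary_head=True,
        index_mapping_dir=None,
        samples_mapping=external_mapping,  # new in this patch: skips the rank-0 build and the np.load of the cached file
    )

Note that the patch leaves the surrounding distributed synchronization in get_samples_mapping() untouched, so the call presumably still needs to be made on all ranks even when an external mapping is supplied.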