Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support in external sample mapping for Megatron datasets #6483

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import os
import subprocess
import time
from typing import Any

import numpy as np
import torch
Expand Down Expand Up @@ -1255,6 +1256,7 @@ def get_samples_mapping(
name,
binary_head,
index_mapping_dir: str = None,
samples_mapping: Any = None,
):
"""Get a list that maps a sample index to a starting sentence index, end sentence index, and length"""

Expand All @@ -1280,8 +1282,8 @@ def get_samples_mapping(
indexmap_filename += '_{}s'.format(seed)
indexmap_filename += '.npy'

# Build the indexed mapping if not exist.
if torch.distributed.get_rank() == 0 and not os.path.isfile(indexmap_filename):
# Build the indexed mapping if it does not exist on disk and was not provided externally.
if samples_mapping is None and torch.distributed.get_rank() == 0 and not os.path.isfile(indexmap_filename):
# Fake index mapping if missing
if (getattr(indexed_dataset, 'doc_idx', None) is None) and (getattr(indexed_dataset, 'sizes', None) is None):
make_indexed_dataset_compatibility(indexed_dataset)
Expand Down Expand Up @@ -1334,15 +1336,16 @@ def get_samples_mapping(
torch.distributed.get_world_size()
// torch.distributed.get_world_size(group=parallel_state.get_tensor_model_parallel_group())
)
# Load indexed dataset.
logging.info(' > loading indexed mapping from {}'.format(indexmap_filename))
start_time = time.time()
samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
logging.info(' loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time))
logging.info(' total number of samples: {}'.format(samples_mapping.shape[0]))
# Load the indexed mapping from disk if it was not given externally.
if samples_mapping is None:
logging.info(' > loading indexed mapping from {}'.format(indexmap_filename))
start_time = time.time()
samples_mapping = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
logging.info(' loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time))
logging.info(' total number of samples: {}'.format(samples_mapping.shape[0]))

# Deallocate temporary numpy arrays that were created for `get_samples_mapping()` when needed
if hasattr(indexed_dataset, 'doc_idx') and hasattr(indexed_dataset, 'sizes'):
deallocate_indexed_dataset_memory(indexed_dataset)
# Deallocate temporary numpy arrays that were created for `get_samples_mapping()` when needed
if hasattr(indexed_dataset, 'doc_idx') and hasattr(indexed_dataset, 'sizes'):
deallocate_indexed_dataset_memory(indexed_dataset)

return samples_mapping