Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FixTextMemMapDataset index file creation in multi-node setup #6768

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 46 additions & 20 deletions nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
import torch

from nemo.core import Dataset
from nemo.utils import logging
from nemo.utils import AppState, logging

__all__ = ['TextMemMapDataset', 'CSVMemMapDataset', 'build_index_files']
__idx_version__ = '0.2' # index file version
__idx_suffix__ = 'idx' # index file suffix
__all__ = ["TextMemMapDataset", "CSVMemMapDataset", "build_index_files"]
__idx_version__ = "0.2" # index file version
__idx_suffix__ = "idx" # index file suffix


def _build_index_from_memdata(fn, newline_int):
Expand All @@ -40,7 +40,7 @@ def _build_index_from_memdata(fn, newline_int):
Returns a 1D array of ints.
"""
# use memmap to read file
mdata = np.memmap(fn, dtype=np.uint8, mode='r')
mdata = np.memmap(fn, dtype=np.uint8, mode="r")
# find newline positions
midx = np.where(mdata == newline_int)[0]
midx_dtype = midx.dtype
Expand Down Expand Up @@ -115,9 +115,10 @@ def __init__(

logging.info(f"Building data files")
# load all files into memmap
is_ditributed = torch.distributed.is_available() and torch.distributed.is_initialized()
is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()

if not is_ditributed or (is_ditributed and torch.distributed.get_rank() == 0):
if not is_distributed or (is_distributed and torch.distributed.get_rank() == 0):
# Create index files on global rank 0.
build_index_files(
dataset_paths,
newline_int,
Expand All @@ -126,14 +127,39 @@ def __init__(
index_mapping_dir=index_mapping_dir,
)

if is_ditributed:
if is_distributed:
torch.distributed.barrier()

if is_distributed and AppState().local_rank == 0:
# If we are in a distributed multi-node set-up and index files are not stored on
# a shared filesystem, then the index files created on global rank 0 are only
# accessible to the workers on that node.
#
# Two cases may occur here:
#
# 1. case of a shared filesystem, or global_rank==0: the index files are present in
# the locally available filesystem, calling build_index_files() again is a no-op.
# 2. case of a non-shared filesystem, and global_rank>0: the index files are not
# present in the locally available filesystem, calling build_index_files() again
# will create them.
#
# Outcome in all cases: all nodes have access to the index files in their filesystem.
build_index_files(
dataset_paths,
newline_int,
workers=self._worker,
build_index_fn=build_index_fn,
index_mapping_dir=index_mapping_dir,
)

if is_distributed:
torch.distributed.barrier()
michalivne marked this conversation as resolved.
Show resolved Hide resolved

logging.info(f"Loading data files")
start_time = time.time()
mdata_midx_list = [self.load_file(fn, index_mapping_dir) for fn in self._files_list]
logging.info(
f'Time loading {len(mdata_midx_list)} mem-mapped files: {datetime.timedelta(seconds=time.time() - start_time)}'
f"Time loading {len(mdata_midx_list)} mem-mapped files: {datetime.timedelta(seconds=time.time() - start_time)}"
)

logging.info("Computing global indices")
Expand Down Expand Up @@ -224,34 +250,34 @@ def load_file(self, fn, index_mapping_dir: Optional[str] = None):
idx_fn = _index_fn(fn, index_mapping_dir)

# create data map
mdata = np.memmap(fn, dtype=np.uint8, mode='r')
mdata = np.memmap(fn, dtype=np.uint8, mode="r")

if _index_file_exists(idx_fn):
# load index file into memory map
midx = np.load(idx_fn + ".npy", allow_pickle=True, mmap_mode='r')
midx = np.load(idx_fn + ".npy", allow_pickle=True, mmap_mode="r")
# test for header
if len(midx) < self._header_lines:
raise RuntimeError(f"Missing header, expected {self._header_lines} header lines")

# load meta info
idx_info_dict = pickle.load(open(idx_fn + ".info", 'rb'))
idx_info_dict = pickle.load(open(idx_fn + ".info", "rb"))
# test for mismatch in expected newline_int
if 'newline_int' in idx_info_dict:
newline_int = idx_info_dict['newline_int']
if "newline_int" in idx_info_dict:
newline_int = idx_info_dict["newline_int"]
if self._newline_int != newline_int:
logging.warning(
f"Mismatch in newline_int, expected = {self._newline_int} but loaded {newline_int}"
)

# test for version mismatch (useful to force recreation of index files)
idx_version = idx_info_dict.get('version', '0.0')
idx_version = idx_info_dict.get("version", "0.0")
if __idx_version__ != idx_version:
raise RuntimeError(
f"Version mismatch: Please delete existing '.{__idx_suffix__}' files. Expected version = {__idx_version__}, but file version = {idx_version}. File path = {idx_fn}"
)
else:
raise ValueError(
f'Memory Map for {fn} is not found, missing one or more of files: {idx_fn}.{{.npy,.info}}'
f"Memory Map for {fn} is not found, missing one or more of files: {idx_fn}.{{.npy,.info}}"
)

return (mdata, midx)
Expand All @@ -271,7 +297,7 @@ def __init__(
tokenizer: Optional[Type["TokenizerSpec"]] = None,
sort_dataset_paths: Optional[bool] = True,
data_col=1,
data_sep=',',
data_sep=",",
index_mapping_dir: Optional[str] = None,
):
"""
Expand Down Expand Up @@ -424,7 +450,7 @@ def _build_memmap_index_files(newline_int, build_index_fn, fn, index_mapping_dir


def build_index_files(
dataset_paths, newline_int, workers=None, build_index_fn=_build_index_from_memdata, index_mapping_dir: str = None
dataset_paths, newline_int, workers=None, build_index_fn=_build_index_from_memdata, index_mapping_dir: str = None,
):
"""Auxiliary method to build multiple index files"""
if len(dataset_paths) < 1:
Expand All @@ -438,10 +464,10 @@ def build_index_files(
start_time = time.time()
with mp.Pool(workers) as p:
build_status = p.map(
partial(_build_memmap_index_files, newline_int, build_index_fn, index_mapping_dir=index_mapping_dir),
partial(_build_memmap_index_files, newline_int, build_index_fn, index_mapping_dir=index_mapping_dir,),
dataset_paths,
)

logging.info(
f'Time building {sum(build_status)} / {len(build_status)} mem-mapped files: {datetime.timedelta(seconds=time.time() - start_time)}'
f"Time building {sum(build_status)} / {len(build_status)} mem-mapped files: {datetime.timedelta(seconds=time.time() - start_time)}"
)
12 changes: 8 additions & 4 deletions tests/collections/nlp/test_mem_map_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,17 @@ def jsonl_file(tmp_path):
file_path = tmp_path / "data.jsonl"

# Generate data to write to the JSONL file
data = [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}, {"name": "Bob", "age": 35}]
data = [
{"name": "John", "age": 30},
{"name": "Jane", "age": 25},
{"name": "Bob", "age": 35},
]

# Write data to the JSONL file
with open(file_path, mode="w") as file:
for item in data:
json.dump(item, file)
file.write('\n')
file.write("\n")

# Provide the file path to the test function
yield str(file_path)
Expand Down Expand Up @@ -81,12 +85,12 @@ def test_csv_mem_map_dataset(csv_file):


@pytest.mark.parametrize(
"dataset_class", [text_memmap_dataset.JSONLMemMapDataset, text_memmap_dataset.CSVMemMapDataset]
"dataset_class", [text_memmap_dataset.JSONLMemMapDataset, text_memmap_dataset.CSVMemMapDataset],
)
@pytest.mark.parametrize("use_alternative_index_mapping_dir", [True, False])
@pytest.mark.parametrize("relative_index_fn", [True, False])
def test_mem_map_dataset_index_mapping_dir(
tmp_path, dataset_class, jsonl_file, use_alternative_index_mapping_dir, relative_index_fn
tmp_path, dataset_class, jsonl_file, use_alternative_index_mapping_dir, relative_index_fn,
):
"""Test for index_mapping_dir."""
if relative_index_fn:
Expand Down