Commit 298a069

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed May 26, 2023
1 parent ba96a7b commit 298a069
Showing 2 changed files with 19 additions and 56 deletions.
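
All of the changes below are mechanical reformatting: multi-line calls and signatures are joined wherever they fit on one line, and imports are regrouped. The page does not name the hooks that ran; assuming a black-style formatter with a long line limit (the 119-character value below is a guess, not taken from the repository's config), the collapse pattern can be reproduced like this:

import black

# One of the multi-line expressions from the diff below, as a source string.
src = (
    "midx_bins = np.cumsum(\n"
    "    [(len(midx) - header_lines) for _, midx in mdata_midx_list]\n"
    ")\n"
)

# line_length=119 is an assumption; the actual hook configuration may differ.
print(black.format_str(src, mode=black.Mode(line_length=119)), end="")
# -> midx_bins = np.cumsum([(len(midx) - header_lines) for _, midx in mdata_midx_list])
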
47 changes: 11 additions & 36 deletions nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py
@@ -18,8 +18,8 @@
 import os
 import pickle
 import time
-from typing import Callable, List, Optional, Type
 from functools import partial
+from typing import Callable, List, Optional, Type

 import numpy as np
 import torch
@@ -74,9 +74,7 @@ def __init__(
         header_lines: Optional[int] = 0,
         workers: Optional[int] = None,
         tokenizer: Optional[Type["TokenizerSpec"]] = None,
-        build_index_fn: Optional[
-            Callable[[str, Optional[int]], bool]
-        ] = _build_index_from_memdata,
+        build_index_fn: Optional[Callable[[str, Optional[int]], bool]] = _build_index_from_memdata,
         sort_dataset_paths: Optional[bool] = True,
         index_mapping_dir: Optional[str] = None,
     ):
@@ -117,9 +115,7 @@ def __init__(

         logging.info(f"Building data files")
         # load all files into memmap
-        is_distributed = (
-            torch.distributed.is_available() and torch.distributed.is_initialized()
-        )
+        is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()

         if not is_distributed or (is_distributed and torch.distributed.get_rank() == 0):
             # Create index files on global rank 0.
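
The guard in this hunk implements a build-once pattern: in a distributed run only global rank 0 creates the shared index files. A minimal runnable sketch of the same idea (the trailing barrier is assumed from the usual pattern; it is not shown in this hunk):

import torch

# Only global rank 0 builds shared files; other ranks wait for it.
is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
if not is_distributed or torch.distributed.get_rank() == 0:
    print("rank 0 (or single process): build index files here")
if is_distributed:
    torch.distributed.barrier()  # everyone waits until rank 0 has finished
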
@@ -161,17 +157,13 @@ def __init__(

         logging.info(f"Loading data files")
         start_time = time.time()
-        mdata_midx_list = [
-            self.load_file(fn, index_mapping_dir) for fn in self._files_list
-        ]
+        mdata_midx_list = [self.load_file(fn, index_mapping_dir) for fn in self._files_list]
         logging.info(
             f"Time loading {len(mdata_midx_list)} mem-mapped files: {datetime.timedelta(seconds=time.time() - start_time)}"
         )

         logging.info("Computing global indices")
-        midx_bins = np.cumsum(
-            [(len(midx) - header_lines) for _, midx in mdata_midx_list]
-        )
+        midx_bins = np.cumsum([(len(midx) - header_lines) for _, midx in mdata_midx_list])

         self.midx_bins = midx_bins
         self.mdata_midx_list = mdata_midx_list
@@ -192,9 +184,7 @@ def __getitem__(self, idx):
         Return a string from binary memmap
         """
         if (idx >= len(self)) or (idx < 0):
-            raise IndexError(
-                f"Index {idx} if out of dataset range with {len(self)} samples"
-            )
+            raise IndexError(f"Index {idx} if out of dataset range with {len(self)} samples")

         # Identify the file containing the record
         file_id = np.digitize(idx, self.midx_bins, right=False)
@@ -225,9 +215,7 @@ def __getitem__(self, idx):
            logging.error(
                f"Error while building data from text, possible issue with sample expected format (see offending sample below): {e}"
            )
-            logging.error(
-                f"sample: {sample}, file_id: {file_id}, file_idx: {file_idx}, i: {i}, j: {j}"
-            )
+            logging.error(f"sample: {sample}, file_id: {file_id}, file_idx: {file_idx}, i: {i}, j: {j}")
            raise e

        return data
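
A note on the indexing scheme these __getitem__ hunks touch: per-file record counts are accumulated with np.cumsum, and np.digitize maps a flat sample index back to the owning file. A minimal self-contained sketch (hypothetical record counts, simplified names):

import numpy as np

records_per_file = [3, 5, 2]             # hypothetical per-file record counts
midx_bins = np.cumsum(records_per_file)  # array([ 3,  8, 10])

def locate(idx):
    """Map a global sample index to (file_id, index within that file)."""
    if (idx >= midx_bins[-1]) or (idx < 0):
        raise IndexError(f"Index {idx} is out of dataset range with {midx_bins[-1]} samples")
    file_id = int(np.digitize(idx, midx_bins, right=False))
    base_idx = int(midx_bins[file_id - 1]) if file_id > 0 else 0
    return file_id, idx - base_idx

assert locate(0) == (0, 0)  # first record of the first file
assert locate(3) == (1, 0)  # first record of the second file
assert locate(9) == (2, 1)  # last record of the last file
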
@@ -269,9 +257,7 @@ def load_file(self, fn, index_mapping_dir: Optional[str] = None):
            midx = np.load(idx_fn + ".npy", allow_pickle=True, mmap_mode="r")
            # test for header
            if len(midx) < self._header_lines:
-                raise RuntimeError(
-                    f"Missing header, expected {self._header_lines} header lines"
-                )
+                raise RuntimeError(f"Missing header, expected {self._header_lines} header lines")

            # load meta info
            idx_info_dict = pickle.load(open(idx_fn + ".info", "rb"))
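
load_file expects two sidecar files per data file: a NumPy offset array (idx_fn + ".npy") opened with mmap_mode="r", and a pickled metadata dict (idx_fn + ".info"). A toy round-trip under those assumptions (the path, offsets, and version string are all invented for illustration):

import pickle
import numpy as np

idx_fn = "/tmp/example.jsonl.idx"  # hypothetical index-file prefix

# Write the two sidecars the loader expects.
np.save(idx_fn + ".npy", np.array([0, 42, 97], dtype=np.int64))  # fake newline offsets
with open(idx_fn + ".info", "wb") as f:
    pickle.dump({"newline_int": 10, "version": "0.2"}, f)  # keys from the diff; values made up

# Read them back the way load_file does.
midx = np.load(idx_fn + ".npy", allow_pickle=True, mmap_mode="r")
with open(idx_fn + ".info", "rb") as f:
    idx_info_dict = pickle.load(f)
print(len(midx), idx_info_dict["newline_int"])
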
@@ -449,9 +435,7 @@ def _build_memmap_index_files(newline_int, build_index_fn, fn, index_mapping_dir
        # validate midx
        midx = np.asarray(midx)
        if not np.issubdtype(midx.dtype, np.integer):
-            raise TypeError(
-                f"midx must be an integer array, but got type = {midx.dtype}"
-            )
+            raise TypeError(f"midx must be an integer array, but got type = {midx.dtype}")

        # create e metadata file
        data = dict(newline_int=newline_int, version=__idx_version__)
@@ -466,11 +450,7 @@


 def build_index_files(
-    dataset_paths,
-    newline_int,
-    workers=None,
-    build_index_fn=_build_index_from_memdata,
-    index_mapping_dir: str = None,
+    dataset_paths, newline_int, workers=None, build_index_fn=_build_index_from_memdata, index_mapping_dir: str = None,
 ):
     """Auxiliary method to build multiple index files"""
     if len(dataset_paths) < 1:
@@ -484,12 +464,7 @@ def build_index_files(
     start_time = time.time()
     with mp.Pool(workers) as p:
         build_status = p.map(
-            partial(
-                _build_memmap_index_files,
-                newline_int,
-                build_index_fn,
-                index_mapping_dir=index_mapping_dir,
-            ),
+            partial(_build_memmap_index_files, newline_int, build_index_fn, index_mapping_dir=index_mapping_dir,),
             dataset_paths,
         )

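In the hunk above, functools.partial pins the arguments shared by every file so that Pool.map only has to iterate over the per-file path. A runnable sketch of the same pattern with a stand-in worker (build_one and the paths are hypothetical, not the real _build_memmap_index_files):

import multiprocessing as mp
from functools import partial

def build_one(newline_int, fn, index_mapping_dir=None):
    # Stand-in for _build_memmap_index_files: pretend to index one file.
    return f"indexed {fn} (newline_int={newline_int}, dir={index_mapping_dir})"

if __name__ == "__main__":
    dataset_paths = ["a.jsonl", "b.jsonl", "c.jsonl"]  # hypothetical paths
    with mp.Pool(2) as p:
        # partial fixes newline_int and index_mapping_dir; map supplies fn.
        build_status = p.map(partial(build_one, 10, index_mapping_dir=None), dataset_paths)
    print(build_status)
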
28 changes: 8 additions & 20 deletions tests/collections/nlp/test_mem_map_dataset.py
@@ -13,9 +13,10 @@
 # limitations under the License.

 import csv
+import json
 import os

 import pytest
-import json
+
 from nemo.collections.nlp.data.language_modeling import text_memmap_dataset

@@ -68,9 +69,7 @@ def csv_file(tmp_path):
 def test_jsonl_mem_map_dataset(jsonl_file):
     """Test for JSONL memory-mapped datasets."""

-    indexed_dataset = text_memmap_dataset.JSONLMemMapDataset(
-        dataset_paths=[jsonl_file], header_lines=0
-    )
+    indexed_dataset = text_memmap_dataset.JSONLMemMapDataset(dataset_paths=[jsonl_file], header_lines=0)
     assert indexed_dataset[0] == {"name": "John", "age": 30}
     assert indexed_dataset[1] == {"name": "Jane", "age": 25}
     assert indexed_dataset[2] == {"name": "Bob", "age": 35}
@@ -79,26 +78,19 @@ def test_jsonl_mem_map_dataset(jsonl_file):
 def test_csv_mem_map_dataset(csv_file):
     """Test for CSV memory-mapped datasets."""

-    indexed_dataset = text_memmap_dataset.CSVMemMapDataset(
-        dataset_paths=[csv_file], data_col=1, header_lines=1
-    )
+    indexed_dataset = text_memmap_dataset.CSVMemMapDataset(dataset_paths=[csv_file], data_col=1, header_lines=1)
     assert indexed_dataset[0].strip() == "John"
     assert indexed_dataset[1].strip() == "Jane"
     assert indexed_dataset[2].strip() == "Bob"


 @pytest.mark.parametrize(
-    "dataset_class",
-    [text_memmap_dataset.JSONLMemMapDataset, text_memmap_dataset.CSVMemMapDataset],
+    "dataset_class", [text_memmap_dataset.JSONLMemMapDataset, text_memmap_dataset.CSVMemMapDataset],
 )
 @pytest.mark.parametrize("use_alternative_index_mapping_dir", [True, False])
 @pytest.mark.parametrize("relative_index_fn", [True, False])
 def test_mem_map_dataset_index_mapping_dir(
-    tmp_path,
-    dataset_class,
-    jsonl_file,
-    use_alternative_index_mapping_dir,
-    relative_index_fn,
+    tmp_path, dataset_class, jsonl_file, use_alternative_index_mapping_dir, relative_index_fn,
 ):
     """Test for index_mapping_dir."""

@@ -109,9 +101,7 @@ def test_mem_map_dataset_index_mapping_dir(
     else:
         jsonl_file = os.path.abspath(jsonl_file)
     dataset_class(
-        dataset_paths=[jsonl_file],
-        header_lines=0,
-        index_mapping_dir=str(index_mapping_dir),
+        dataset_paths=[jsonl_file], header_lines=0, index_mapping_dir=str(index_mapping_dir),
     )
     # Index files should not be created in default location.
     assert not os.path.isfile(f"{jsonl_file}.idx.npy")
@@ -121,8 +111,6 @@ def test_mem_map_dataset_index_mapping_dir(
         assert os.path.isfile(f"{idx_fn}.npy")
         assert os.path.isfile(f"{idx_fn}.info")
     else:
-        text_memmap_dataset.JSONLMemMapDataset(
-            dataset_paths=[jsonl_file], header_lines=0
-        )
+        text_memmap_dataset.JSONLMemMapDataset(dataset_paths=[jsonl_file], header_lines=0)
         assert os.path.isfile(f"{jsonl_file}.idx.npy")
         assert os.path.isfile(f"{jsonl_file}.idx.info")

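A note on the stacked pytest.mark.parametrize decorators in the last test: each decorator multiplies the parameter sets, so the three decorators above run the test for every combination (2 classes × 2 × 2 flags = 8 invocations). A minimal sketch of the same mechanism:

import pytest

@pytest.mark.parametrize("dataset_kind", ["jsonl", "csv"])
@pytest.mark.parametrize("use_alt_dir", [True, False])
def test_cross_product(dataset_kind, use_alt_dir):
    # Stacked parametrize decorators yield the cross-product:
    # 2 kinds x 2 flags = 4 invocations of this test.
    assert dataset_kind in {"jsonl", "csv"}
    assert isinstance(use_alt_dir, bool)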