Skip to content

Commit

Permalink
Sharded manifests for tarred datasets (#6395)
Browse files Browse the repository at this point in the history
* testing sharded manifests

Signed-off-by: Dima Rekesh <bmwshop@gmail.com>

* compatibility

Signed-off-by: Dima Rekesh <bmwshop@gmail.com>

* proper fixes

Signed-off-by: Dima Rekesh <bmwshop@gmail.com>

* adding flag to convert_to_tarred_audio_dataset

Signed-off-by: Dima Rekesh <bmwshop@gmail.com>

* shard_manifests conf param

Signed-off-by: Dima Rekesh <bmwshop@gmail.com>

* propagating the shard_manifests param

Signed-off-by: Dima Rekesh <bmwshop@gmail.com>

* propagating the shard_manifests param

Signed-off-by: Dima Rekesh <bmwshop@gmail.com>

* distributed checks

Signed-off-by: Dima Rekesh <bmwshop@gmail.com>

* typo

Signed-off-by: Dima Rekesh <bmwshop@gmail.com>

* typo

Signed-off-by: Dima Rekesh <bmwshop@gmail.com>

* fixes

Signed-off-by: Dima Rekesh <bmwshop@gmail.com>

* fixes

Signed-off-by: Dima Rekesh <bmwshop@gmail.com>

* fixes

Signed-off-by: Dima Rekesh <bmwshop@gmail.com>

* fixes

Signed-off-by: Dima Rekesh <bmwshop@gmail.com>

* fixes

Signed-off-by: Dima Rekesh <bmwshop@gmail.com>

* fixes

Signed-off-by: Dima Rekesh <bmwshop@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes based on PR comments and tests

Signed-off-by: Dima Rekesh <bmwshop@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes to convert_to_tarred_audio_dataset.py

Signed-off-by: Dima Rekesh <bmwshop@gmail.com>

* reversing manifest shards flag

Signed-off-by: Dima Rekesh <bmwshop@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tests

Signed-off-by: Dima Rekesh <bmwshop@gmail.com>

* excluding manifests from webdataset url expansion

Signed-off-by: Dima Rekesh <bmwshop@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* expand manifest paths before attempting to cache from datastore

Signed-off-by: Dima Rekesh <bmwshop@gmail.com>

* explicit use of UTF-8 for manifest i/o

Signed-off-by: Dima Rekesh <bmwshop@gmail.com>

---------

Signed-off-by: Dima Rekesh <bmwshop@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
bmwshop and pre-commit-ci[bot] authored Apr 18, 2023
1 parent 536ee62 commit ceb539f
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 28 deletions.
6 changes: 3 additions & 3 deletions nemo/collections/asr/data/audio_to_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch
import webdataset as wd

from nemo.collections.asr.data.audio_to_text import cache_datastore_manifests, expand_audio_filepaths
from nemo.collections.asr.data.audio_to_text import cache_datastore_manifests, expand_sharded_filepaths
from nemo.collections.asr.parts.preprocessing.segment import available_formats as valid_sf_formats
from nemo.collections.common.parts.preprocessing import collections
from nemo.core.classes import Dataset, IterableDataset
Expand Down Expand Up @@ -560,8 +560,8 @@ def __init__(
for idx in range(len(self.labels[:5])):
logging.debug(" label id {} and its mapped label {}".format(idx, self.id2label[idx]))

audio_tar_filepaths = expand_audio_filepaths(
audio_tar_filepaths=audio_tar_filepaths,
audio_tar_filepaths = expand_sharded_filepaths(
sharded_filepaths=audio_tar_filepaths,
shard_strategy=shard_strategy,
world_size=world_size,
global_rank=global_rank,
Expand Down
110 changes: 91 additions & 19 deletions nemo/collections/asr/data/audio_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,47 +171,48 @@ def process_text_by_sample(self, sample: collections.ASRAudioText.OUTPUT_TYPE) -
return t, tl


def expand_audio_filepaths(audio_tar_filepaths, shard_strategy: str, world_size: int, global_rank: int):
def expand_sharded_filepaths(sharded_filepaths, shard_strategy: str, world_size: int, global_rank: int):
valid_shard_strategies = ['scatter', 'replicate']
if shard_strategy not in valid_shard_strategies:
raise ValueError(f"`shard_strategy` must be one of {valid_shard_strategies}")

if isinstance(audio_tar_filepaths, str):
if isinstance(sharded_filepaths, str):
# Replace '(' and '[' with '{'
brace_keys_open = ['(', '[', '<', '_OP_']
for bkey in brace_keys_open:
if bkey in audio_tar_filepaths:
audio_tar_filepaths = audio_tar_filepaths.replace(bkey, "{")
if bkey in sharded_filepaths:
sharded_filepaths = sharded_filepaths.replace(bkey, "{")

# Replace ')' and ']' with '}'
brace_keys_close = [')', ']', '>', '_CL_']
for bkey in brace_keys_close:
if bkey in audio_tar_filepaths:
audio_tar_filepaths = audio_tar_filepaths.replace(bkey, "}")
if bkey in sharded_filepaths:
sharded_filepaths = sharded_filepaths.replace(bkey, "}")

if isinstance(audio_tar_filepaths, str):
if isinstance(sharded_filepaths, str):
# Brace expand
audio_tar_filepaths = list(braceexpand.braceexpand(audio_tar_filepaths))
sharded_filepaths = list(braceexpand.braceexpand(sharded_filepaths))

# Expand store paths into WebDataset URLs
audio_tar_filepaths = [
datastore_path_to_webdataset_url(p) if is_datastore_path(p) else p for p in audio_tar_filepaths
sharded_filepaths = [
datastore_path_to_webdataset_url(p) if is_datastore_path(p) and is_tarred_path(p) else p
for p in sharded_filepaths
]

# Check for distributed and partition shards accordingly
if world_size > 1:
if shard_strategy == 'scatter':
logging.info("All tarred dataset shards will be scattered evenly across all nodes.")

if len(audio_tar_filepaths) % world_size != 0:
if len(sharded_filepaths) % world_size != 0:
logging.warning(
f"Number of shards in tarred dataset ({len(audio_tar_filepaths)}) is not divisible "
f"Number of shards in tarred dataset ({len(sharded_filepaths)}) is not divisible "
f"by number of distributed workers ({world_size})."
)

begin_idx = (len(audio_tar_filepaths) // world_size) * global_rank
end_idx = begin_idx + len(audio_tar_filepaths) // world_size
audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]
begin_idx = (len(sharded_filepaths) // world_size) * global_rank
end_idx = begin_idx + len(sharded_filepaths) // world_size
sharded_filepaths = sharded_filepaths[begin_idx:end_idx]
logging.info(
"Partitioning tarred dataset: process (%d) taking shards [%d, %d)", global_rank, begin_idx, end_idx
)
Expand All @@ -221,7 +222,7 @@ def expand_audio_filepaths(audio_tar_filepaths, shard_strategy: str, world_size:
else:
raise ValueError(f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}")

return audio_tar_filepaths
return sharded_filepaths


def cache_datastore_manifests(
Expand Down Expand Up @@ -345,6 +346,47 @@ def cache_data(manifest_filepaths, cache_audio, num_workers, max_num_workers):
)


def shard_manifests_if_needed(
    manifest_filepaths: Union[str, List[str]],
    shard_strategy: str,
    shard_manifests: bool,
    global_rank: int,
    world_size: int,
):
    """Optionally expand / shard the list of manifests.

    Uses the same sharded-path notation as the tarred audio files, so manifest
    paths may be written with brace expansion (e.g. ``manifest_{0..7}.json``)
    and partitioned across distributed workers.

    Args:
        manifest_filepaths: manifest file path(s), possibly in sharded notation.
        shard_strategy: 'scatter' or 'replicate' ('scatter' partitions shards
            across workers; same semantics as for the audio tar shards).
        shard_manifests: if False, no sharding / manifest filepath expansion
            is attempted and the input is returned unchanged.
        global_rank: the rank of this worker.
        world_size: total number of workers.

    Returns:
        The (possibly expanded and partitioned) manifest filepath(s).
    """
    if shard_manifests:
        # Sharding requires distributed coordination; without it, fall back
        # to the unmodified paths rather than risking divergent partitions.
        if not torch.distributed.is_available():
            logging.warning("Not running in torch.distributed mode. Manifest sharding not available")
            return manifest_filepaths

        if not torch.distributed.is_initialized():
            logging.warning(
                'Manifest sharding was requested but torch.distributed is not initialized '
                'Did you intend to set the defer_setup flag?'
            )
            return manifest_filepaths

        # Brace-expand and partition the manifest paths exactly like audio shards.
        manifest_filepaths = expand_sharded_filepaths(
            sharded_filepaths=manifest_filepaths,
            shard_strategy=shard_strategy,
            world_size=world_size,
            global_rank=global_rank,
        )

    return manifest_filepaths


class _AudioTextDataset(Dataset):
"""
Dataset that loads tensors via a json file containing paths to audio files, transcripts, and durations (in seconds).
Expand Down Expand Up @@ -748,6 +790,7 @@ class _TarredAudioToTextDataset(IterableDataset):
occasions (when the number of shards is not divisible with ``world_size``), will not sample
the entire dataset. For these reasons it is not advisable to use tarred datasets as validation
or test datasets.
shard_manifests (bool): Whether or not to try / shard manifests. Defaults to False.
global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
world_size (int): Total number of processes, used for partitioning shards. Defaults to 0.
return_sample_id (bool): whether to return the sample_id as a part of each sample
Expand All @@ -769,10 +812,22 @@ def __init__(
eos_id: Optional[int] = None,
pad_id: int = 0,
shard_strategy: str = "scatter",
shard_manifests: bool = False,
global_rank: int = 0,
world_size: int = 0,
return_sample_id: bool = False,
):
self.shard_manifests = shard_manifests

# Shard manifests if necessary and possible and then expand the paths
manifest_filepath = shard_manifests_if_needed(
shard_manifests=shard_manifests,
shard_strategy=shard_strategy,
manifest_filepaths=manifest_filepath,
world_size=world_size,
global_rank=global_rank,
)

# If necessary, cache manifests from object store
cache_datastore_manifests(manifest_filepaths=manifest_filepath)

Expand All @@ -788,15 +843,17 @@ def __init__(
index_by_file_id=True, # Must set this so the manifest lines can be indexed by file ID
)

self.len = self._compute_len()

self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor)
self.trim = trim
self.eos_id = eos_id
self.bos_id = bos_id
self.pad_id = pad_id
self.return_sample_id = return_sample_id

audio_tar_filepaths = expand_audio_filepaths(
audio_tar_filepaths=audio_tar_filepaths,
audio_tar_filepaths = expand_sharded_filepaths(
sharded_filepaths=audio_tar_filepaths,
shard_strategy=shard_strategy,
world_size=world_size,
global_rank=global_rank,
Expand Down Expand Up @@ -928,8 +985,19 @@ def get_manifest_sample(self, sample_id):
def __iter__(self):
return self._dataset.__iter__()

    def _compute_len(self):
        """Compute the dataset length, summed across workers when manifests are sharded.

        With sharded manifests each worker only loaded its own slice of the
        collection, so the global length is obtained via an all-reduce over
        the per-worker lengths; otherwise the local collection length is used.
        """
        if self.shard_manifests and torch.distributed.is_available() and torch.distributed.is_initialized():
            # Sum per-worker collection sizes across all ranks.
            # NOTE(review): .cuda() assumes a CUDA device / NCCL-style backend — confirm
            # this is only reached in GPU distributed runs.
            my_len = torch.tensor(len(self.manifest_processor.collection), dtype=torch.int32).cuda()
            torch.distributed.all_reduce(my_len)
            my_len = my_len.int()
            logging.info(f'Sharded manifests: Total length: {my_len}')
            # NOTE(review): in this branch my_len is a 0-dim int32 tensor, not a plain int;
            # callers (e.g. __len__) rely on the tensor's integer conversion — verify.
        else:
            my_len = len(self.manifest_processor.collection)

        return my_len

def __len__(self):
return len(self.manifest_processor.collection)
return self.len


class TarredAudioToCharDataset(_TarredAudioToTextDataset):
Expand Down Expand Up @@ -1042,6 +1110,7 @@ def __init__(
parser: Optional[str] = 'en',
pad_id: int = 0,
shard_strategy: str = "scatter",
shard_manifests: bool = False,
global_rank: int = 0,
world_size: int = 0,
return_sample_id: bool = False,
Expand All @@ -1067,6 +1136,7 @@ def __init__(
eos_id=eos_id,
pad_id=pad_id,
shard_strategy=shard_strategy,
shard_manifests=shard_manifests,
global_rank=global_rank,
world_size=world_size,
return_sample_id=return_sample_id,
Expand Down Expand Up @@ -1167,6 +1237,7 @@ def __init__(
trim: bool = False,
use_start_end_token: bool = True,
shard_strategy: str = "scatter",
shard_manifests: bool = False,
global_rank: int = 0,
world_size: int = 0,
return_sample_id: bool = False,
Expand Down Expand Up @@ -1219,6 +1290,7 @@ def __call__(self, *args):
eos_id=eos_id,
pad_id=pad_id,
shard_strategy=shard_strategy,
shard_manifests=shard_manifests,
global_rank=global_rank,
world_size=world_size,
return_sample_id=return_sample_id,
Expand Down
6 changes: 3 additions & 3 deletions nemo/collections/asr/data/audio_to_text_dali.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import torch
from omegaconf import DictConfig

from nemo.collections.asr.data.audio_to_text import ASRManifestProcessor, expand_audio_filepaths
from nemo.collections.asr.data.audio_to_text import ASRManifestProcessor, expand_sharded_filepaths
from nemo.collections.common.parts.preprocessing import parsers
from nemo.utils import logging, model_utils

Expand Down Expand Up @@ -345,10 +345,10 @@ def __init__(
self.is_tarred_dataset = False

elif audio_tar_filepaths is not None and audio_tar_index_filepaths is not None:
audio_tar_filepaths = expand_audio_filepaths(
audio_tar_filepaths = expand_sharded_filepaths(
audio_tar_filepaths, shard_strategy=shard_strategy, world_size=world_size, global_rank=global_rank
)
audio_tar_index_filepaths = expand_audio_filepaths(
audio_tar_index_filepaths = expand_sharded_filepaths(
audio_tar_index_filepaths,
shard_strategy=shard_strategy,
world_size=world_size,
Expand Down
5 changes: 5 additions & 0 deletions nemo/collections/asr/data/audio_to_text_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,9 @@ def get_tarred_dataset(
):
if len(tarred_audio_filepath) == 1:
tarred_audio_filepath = tarred_audio_filepath[0]
if len(manifest_filepath) == 1:
manifest_filepath = manifest_filepath[0]

if tokenizer is None:
dataset = audio_to_text.TarredAudioToCharDataset(
audio_tar_filepaths=tarred_audio_filepath,
Expand All @@ -363,6 +366,7 @@ def get_tarred_dataset(
trim=config.get('trim_silence', False),
parser=config.get('parser', 'en'),
shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
shard_manifests=config.get('shard_manifests', False),
global_rank=global_rank,
world_size=world_size,
return_sample_id=config.get('return_sample_id', False),
Expand All @@ -381,6 +385,7 @@ def get_tarred_dataset(
trim=config.get('trim_silence', False),
use_start_end_token=config.get('use_start_end_token', True),
shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
shard_manifests=config.get('shard_manifests', False),
global_rank=global_rank,
world_size=world_size,
return_sample_id=config.get('return_sample_id', False),
Expand Down
1 change: 1 addition & 0 deletions nemo/collections/asr/models/configs/asr_models_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class ASRDatasetConfig(nemo.core.classes.dataset.DatasetConfig):
is_tarred: bool = False
tarred_audio_filepaths: Optional[Any] = None
tarred_shard_strategy: str = "scatter"
shard_manifests: bool = False
shuffle_n: int = 0

# Optional
Expand Down
6 changes: 6 additions & 0 deletions nemo/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ def is_datastore_path(path) -> bool:
return path.startswith('ais://')


def is_tarred_path(path) -> bool:
    """Return True when the given path refers to a tar archive, judged by its suffix."""
    tar_suffix = '.tar'
    return path.endswith(tar_suffix)


def is_datastore_cache_shared() -> bool:
"""Check if store cache is shared.
"""
Expand Down
Loading

0 comments on commit ceb539f

Please sign in to comment.