diff --git a/docs/source/package_reference/main_classes.mdx b/docs/source/package_reference/main_classes.mdx
index 1163af5e4e6..d44826f36ea 100644
--- a/docs/source/package_reference/main_classes.mdx
+++ b/docs/source/package_reference/main_classes.mdx
@@ -100,6 +100,8 @@ The base class [`Dataset`] implements a Dataset backed by an Apache Arrow table.
 
 [[autodoc]] datasets.interleave_datasets
 
+[[autodoc]] datasets.distributed.split_dataset_by_node
+
 [[autodoc]] datasets.enable_caching
 
 [[autodoc]] datasets.disable_caching
diff --git a/docs/source/use_with_pytorch.mdx b/docs/source/use_with_pytorch.mdx
index 5a812b3c1e8..3b2b71463bb 100644
--- a/docs/source/use_with_pytorch.mdx
+++ b/docs/source/use_with_pytorch.mdx
@@ -201,7 +201,7 @@ You must use a `BatchSampler` if you want the transform to be given full batches
 ### Stream data
 
 Loading a dataset in streaming mode allows one to iterate over the dataset without downloading it on disk.
-An iterable dataset from `datasets` inherits from `torch.utils.data.IterableDataset` so you can pass it to a `DataLoader`:
+An iterable dataset from `datasets` inherits from `torch.utils.data.IterableDataset` so you can pass it to a `torch.utils.data.DataLoader`:
 
 ```py
 >>> import numpy as np
@@ -222,4 +222,30 @@ If the dataset is split in several shards (i.e. if the dataset consists of multi
 >>> dataloader = DataLoader(ds, batch_size=32, num_workers=4)
 ```
 
-In this case each worker will be given a subset of the list of shards to stream from.
+In this case each worker is given a subset of the list of shards to stream from.
+
+### Distributed
+
+To split your dataset across your training nodes, you can use [`datasets.distributed.split_dataset_by_node`]:
+
+```python
+import os
+from datasets.distributed import split_dataset_by_node
+
+ds = split_dataset_by_node(ds, rank=int(os.environ["RANK"]), world_size=int(os.environ["WORLD_SIZE"]))
+```
+
+This works for both map-style datasets and iterable datasets.
+The dataset is split for the node at rank `rank` in a pool of nodes of size `world_size`.
+
+For map-style datasets:
+
+Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of the dataset.
+
+For iterable datasets:
+
+If the dataset has a number of shards that is a multiple of `world_size` (i.e. if `dataset.n_shards % world_size == 0`),
+then the shards are evenly assigned across the nodes, which is the most efficient.
+Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.
+
+This can also be combined with a `torch.utils.data.DataLoader` if you want each node to use multiple workers to load the data.
diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 20d04faafe4..8e245f8986e 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -5628,6 +5628,26 @@ def iter_random_indices():
     return concatenated_datasets.select(indices, **kwargs)
 
 
+def _split_by_node_map_style_dataset(dataset: Dataset, rank: int, world_size: int) -> Dataset:
+    """
+    Split a dataset for the node at rank `rank` in a pool of nodes of size `world_size`.
+    Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of the dataset.
+    To maximize data loading throughput, chunks are made of contiguous data on disk if possible.
+
+    Args:
+        dataset ([`Dataset`]):
+            The dataset to split by node.
+        rank (`int`):
+            Rank of the current node.
+        world_size (`int`):
+            Total number of nodes.
+
+    Returns:
+        [`Dataset`]: The dataset to be used on the node at rank `rank`.
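+
+    Example (illustrative sketch; output assumes a hypothetical 10-row dataset split across 2 nodes):
+
+    ```py
+    >>> from datasets import Dataset
+    >>> ds = Dataset.from_dict({"i": range(10)})
+    >>> _split_by_node_map_style_dataset(ds, rank=0, world_size=2)[:]  # rank 0 gets the first contiguous chunk
+    {'i': [0, 1, 2, 3, 4]}
+    ```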
+    """
+    return dataset.shard(num_shards=world_size, index=rank, contiguous=True)
+
+
 # This is outside Dataset.filter as it needs to be picklable for multiprocessing
diff --git a/src/datasets/combine.py b/src/datasets/combine.py
index f8a05de694a..af79dba0a5f 100644
--- a/src/datasets/combine.py
+++ b/src/datasets/combine.py
@@ -10,7 +10,7 @@
 logger = logging.get_logger(__name__)
 
 
-DatasetType = TypeVar("DatasetType", "Dataset", "IterableDataset")
+DatasetType = TypeVar("DatasetType", Dataset, IterableDataset)
 
 
 def interleave_datasets(
@@ -137,11 +137,11 @@ def interleave_datasets(
 
 
 def concatenate_datasets(
-    dsets: List[Dataset],
+    dsets: List[DatasetType],
     info: Optional[DatasetInfo] = None,
     split: Optional[NamedSplit] = None,
     axis: int = 0,
-):
+) -> DatasetType:
     """
     Converts a list of [`Dataset`] with the same schema into a single [`Dataset`].
 
diff --git a/src/datasets/distributed.py b/src/datasets/distributed.py
new file mode 100644
index 00000000000..e036fabaf2c
--- /dev/null
+++ b/src/datasets/distributed.py
@@ -0,0 +1,39 @@
+from typing import TypeVar
+
+from .arrow_dataset import Dataset, _split_by_node_map_style_dataset
+from .iterable_dataset import IterableDataset, _split_by_node_iterable_dataset
+
+
+DatasetType = TypeVar("DatasetType", Dataset, IterableDataset)
+
+
+def split_dataset_by_node(dataset: DatasetType, rank: int, world_size: int) -> DatasetType:
+    """
+    Split a dataset for the node at rank `rank` in a pool of nodes of size `world_size`.
+
+    For map-style datasets:
+
+    Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of the dataset.
+    To maximize data loading throughput, chunks are made of contiguous data on disk if possible.
+
+    For iterable datasets:
+
+    If the dataset has a number of shards that is a multiple of `world_size` (i.e. if `dataset.n_shards % world_size == 0`),
+    then the shards are evenly assigned across the nodes, which is the most efficient.
+    Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.
+
+    Args:
+        dataset ([`Dataset`] or [`IterableDataset`]):
+            The dataset to split by node.
+        rank (`int`):
+            Rank of the current node.
+        world_size (`int`):
+            Total number of nodes.
+
+    Returns:
+        [`Dataset`] or [`IterableDataset`]: The dataset to be used on the node at rank `rank`.
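+
+    Example (illustrative sketch; assumes the `RANK` and `WORLD_SIZE` environment variables set by your launcher, and an existing dataset `ds`):
+
+    ```py
+    >>> import os
+    >>> from datasets.distributed import split_dataset_by_node
+    >>> ds = split_dataset_by_node(ds, rank=int(os.environ["RANK"]), world_size=int(os.environ["WORLD_SIZE"]))
+    ```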
+    """
+    if isinstance(dataset, Dataset):
+        return _split_by_node_map_style_dataset(dataset, rank=rank, world_size=world_size)
+    else:
+        return _split_by_node_iterable_dataset(dataset, rank=rank, world_size=world_size)
diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py
index 09becb7f3f3..1c68e2f5f62 100644
--- a/src/datasets/iterable_dataset.py
+++ b/src/datasets/iterable_dataset.py
@@ -20,7 +20,7 @@
 from .splits import NamedSplit
 from .table import table_cast
 from .utils.logging import get_logger
-from .utils.sharding import _number_of_shards_in_gen_kwargs, _shuffle_gen_kwargs, _split_gen_kwargs
+from .utils.sharding import _merge_gen_kwargs, _number_of_shards_in_gen_kwargs, _shuffle_gen_kwargs, _split_gen_kwargs
 
 
 logger = get_logger(__name__)
@@ -95,7 +95,7 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "_BaseExamples
         """
         raise NotImplementedError(f"{type(self)} doesn't implement shuffle_data_sources yet")
 
-    def shard_data_sources(self, shard_idx: int) -> "_BaseExamplesIterable":
+    def shard_data_sources(self, shard_indices: List[int]) -> "_BaseExamplesIterable":
         """Either keep only the requested shard, or propagate the request to the underlying iterable."""
         raise NotImplementedError(f"{type(self)} doesn't implement shard_data_sources yet")
 
@@ -113,19 +113,20 @@ def __iter__(self):
         yield from self.generate_examples_fn(**self.kwargs)
 
     def shuffle_data_sources(self, generator: np.random.Generator) -> "ExamplesIterable":
-        return ShardShuffledExamplesIterable(self.generate_examples_fn, self.kwargs, generator)
+        return ShuffledDataSourcesExamplesIterable(self.generate_examples_fn, self.kwargs, generator)
 
-    def shard_data_sources(self, shard_idx: int) -> "ExamplesIterable":
+    def shard_data_sources(self, shard_indices: List[int]) -> "ExamplesIterable":
         """Keep only the requested shard."""
-        kwargs_with_requested_data_source = _split_gen_kwargs(self.kwargs, max_num_jobs=self.n_shards)[shard_idx]
-        return ExamplesIterable(self.generate_examples_fn, kwargs_with_requested_data_source)
+        gen_kwargs_list = _split_gen_kwargs(self.kwargs, max_num_jobs=self.n_shards)
+        requested_gen_kwargs = _merge_gen_kwargs([gen_kwargs_list[i] for i in shard_indices])
+        return ExamplesIterable(self.generate_examples_fn, requested_gen_kwargs)
 
     @property
     def n_shards(self) -> int:
         return _number_of_shards_in_gen_kwargs(self.kwargs)
 
 
-class ShardShuffledExamplesIterable(ExamplesIterable):
+class ShuffledDataSourcesExamplesIterable(ExamplesIterable):
     def __init__(self, generate_examples_fn: Callable, kwargs: dict, generator: np.random.Generator):
         super().__init__(generate_examples_fn, kwargs)
         self.generator = deepcopy(generator)
@@ -136,14 +137,43 @@ def __iter__(self):
         kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs)
         yield from self.generate_examples_fn(**kwargs_with_shuffled_shards)
 
-    def shard_data_sources(self, shard_idx: int) -> "ExamplesIterable":
+    def shard_data_sources(self, shard_indices: List[int]) -> "ExamplesIterable":
         """Keep only the requested shard."""
         rng = deepcopy(self.generator)
         kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs)
-        kwargs_with_requested_data_source = _split_gen_kwargs(kwargs_with_shuffled_shards, max_num_jobs=self.n_shards)[
-            shard_idx
-        ]
-        return ExamplesIterable(self.generate_examples_fn, kwargs_with_requested_data_source)
+        return ExamplesIterable(self.generate_examples_fn, kwargs_with_shuffled_shards).shard_data_sources(
+            shard_indices
+        )
+
+
+class StepExamplesIterable(_BaseExamplesIterable):
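+    """Yield one example out of every `step` examples of the wrapped iterable, starting at example number `offset`.
+
+    This is the fallback used by `split_dataset_by_node` when an iterable dataset's shards can't be evenly distributed across nodes.
+    """
+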
+    def __init__(self, ex_iterable: _BaseExamplesIterable, step: int, offset: int):
+        self.ex_iterable = ex_iterable
+        self.step = step
+        self.offset = offset
+
+    def __iter__(self):
+        ex_iterator = iter(self.ex_iterable)
+        while True:
+            batch = list(islice(ex_iterator, self.step))
+            if len(batch) > self.offset:
+                yield batch[self.offset]
+            else:
+                break
+
+    def shuffle_data_sources(self, generator: np.random.Generator) -> "StepExamplesIterable":
+        return StepExamplesIterable(
+            self.ex_iterable.shuffle_data_sources(generator), step=self.step, offset=self.offset
+        )
+
+    def shard_data_sources(self, shard_indices: List[int]) -> "StepExamplesIterable":
+        return StepExamplesIterable(
+            self.ex_iterable.shard_data_sources(shard_indices), step=self.step, offset=self.offset
+        )
+
+    @property
+    def n_shards(self) -> int:
+        return self.ex_iterable.n_shards
 
 
 class CyclingMultiSourcesExamplesIterable(_BaseExamplesIterable):
@@ -723,6 +753,12 @@ class ShufflingConfig:
     generator: np.random.Generator
 
 
+@dataclass
+class DistributedConfig:
+    rank: int
+    world_size: int
+
+
 def _maybe_add_torch_iterable_dataset_parent_class(cls):
     """Add torch.utils.data.IterableDataset as a parent class if 'torch' is available"""
     if config.TORCH_AVAILABLE:
@@ -742,6 +778,7 @@ def __init__(
         split: Optional[NamedSplit] = None,
         format_type: Optional[str] = None,
         shuffling: Optional[ShufflingConfig] = None,
+        distributed: Optional[DistributedConfig] = None,
         token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None,
     ):
         info = info.copy() if info is not None else DatasetInfo()
@@ -750,6 +787,7 @@ def __init__(
         self._ex_iterable = ex_iterable
         self._format_type = format_type
         self._shuffling = shuffling
+        self._distributed = distributed
         self._epoch = 0
         self._token_per_repo_id: Dict[str, Union[str, bool, None]] = token_per_repo_id or {}
         _maybe_add_torch_iterable_dataset_parent_class(self.__class__)
@@ -763,7 +801,7 @@ def __setstate__(self, d):
         _maybe_add_torch_iterable_dataset_parent_class(self.__class__)
 
     def _head(self, n=5):
-        return _examples_to_batch([x for key, x in islice(self._iter(), n)])
+        return _examples_to_batch(list(self.take(n)))
 
     def _effective_generator(self):
         if self._shuffling and self._epoch == 0:
@@ -778,72 +816,105 @@ def _effective_generator(self):
 
     @property
     def n_shards(self) -> int:
+        if self._distributed and self._ex_iterable.n_shards % self._distributed.world_size == 0:
+            return self._ex_iterable.n_shards // self._distributed.world_size
         return self._ex_iterable.n_shards
 
-    def _iter(self):
-        if self._shuffling:
-            ex_iterable = self._ex_iterable.shuffle_data_sources(self._effective_generator())
+    def _iter_pytorch(self, ex_iterable: _BaseExamplesIterable):
+        # fix for fsspec when using multiprocess
+        _reset_fsspec_lock()
+        # check if there aren't too many workers
+        import torch.utils.data
+
+        worker_info = torch.utils.data.get_worker_info()
+        if self._is_main_process() and ex_iterable.n_shards < worker_info.num_workers:
+            logger.warning(
+                f"Too many dataloader workers: {worker_info.num_workers} (max is dataset.n_shards={ex_iterable.n_shards}). "
+                f"Stopping {worker_info.num_workers - ex_iterable.n_shards} dataloader workers."
+            )
+            logger.warning(
+                f"To parallelize data loading, we give each process some shards (or data sources) to process. "
+                f"Therefore it's unnecessary to have a number of workers greater than dataset.n_shards={ex_iterable.n_shards}. "
+                f"To enable more parallelism, please split the dataset into more files than {ex_iterable.n_shards}."
+            )
+        # split workload
+        shards_indices = list(range(worker_info.id, ex_iterable.n_shards, worker_info.num_workers))
+        _log_prefix = f"node#{self._distributed.rank} " if self._distributed else ""
+        if shards_indices:
+            logger.debug(
+                f"{_log_prefix}dataloader worker#{worker_info.id}: Starting to iterate over {len(shards_indices)}/{ex_iterable.n_shards} shards."
+            )
+            for key, example in ex_iterable.shard_data_sources(shards_indices):
+                if self.features:
+                    yield _apply_feature_types_on_example(
+                        example, self.features, token_per_repo_id=self._token_per_repo_id
+                    )
+                else:
+                    yield example
+            logger.debug(
+                f"{_log_prefix}dataloader worker#{worker_info.id}: Finished iterating over {len(shards_indices)}/{ex_iterable.n_shards} shards."
+            )
         else:
-            ex_iterable = self._ex_iterable
-        yield from ex_iterable
+            logger.debug(
+                f"{_log_prefix}dataloader worker#{worker_info.id}: Stopping... Number of dataset shards < num_workers ({ex_iterable.n_shards}<{worker_info.num_workers})."
+            )
 
-    def _iter_shard(self, shard_idx: int):
+    def _is_main_process(self):
+        if self._distributed and self._distributed.rank > 0:
+            return False
+        if "torch" in sys.modules:
+            import torch.utils.data
+
+            worker_info = torch.utils.data.get_worker_info()
+            if worker_info is not None and worker_info.id > 0:
+                return False
+        return True
+
+    def _prepare_ex_iterable_for_iteration(self) -> _BaseExamplesIterable:
         if self._shuffling:
             ex_iterable = self._ex_iterable.shuffle_data_sources(self._effective_generator())
         else:
             ex_iterable = self._ex_iterable
-        yield from ex_iterable.shard_data_sources(shard_idx)
 
-    def _iter_pytorch(self, worker_info):
-        # fix for fsspec when using multprocess
-        _reset_fsspec_lock()
-        if worker_info is None:  # single-process data loading, return the full iterator
-            yield from IterableDataset.__iter__(self)
-        else:  # in a worker process
-            # check if there aren't too many workers
-            if worker_info.id == 0 and self.n_shards < worker_info.num_workers:
-                logger.warning(
-                    f"Too many dataloader workers: {worker_info.num_workers} (max is dataset.n_shards={self.n_shards}). "
-                    f"Stopping dataloader workers [{self.n_shards}...{worker_info.num_workers -1}]."
-                )
-                logger.warning(
-                    f"To parallelize data loading, we give each process some shards (or data sources) to process. "
-                    f"Therefore it's unnecessary to have a number of workers greater than dataset.n_shards={self.n_shards}."
-                    f"To enable more parallelism, please split the dataset in more files than {self.n_shards}."
-                )
-            # split workload
-            shards_indices = list(range(worker_info.id, self.n_shards, worker_info.num_workers))
-            if shards_indices:
-                logger.debug(
-                    f"dataloader worker#{worker_info.id}, ': Starting to iterate over {len(shards_indices)}/{self.n_shards} shards."
-                )
-                for shard_idx in shards_indices:
-                    for key, example in self._iter_shard(shard_idx):
-                        if self.features:
-                            yield _apply_feature_types_on_example(
-                                example, self.features, token_per_repo_id=self._token_per_repo_id
-                            )
-                        else:
-                            yield example
-                logger.debug(
-                    f"dataloader worker#{worker_info.id}, ': Finished iterating over {len(shards_indices)}/{self.n_shards} shards."
-                )
+        if self._distributed:
+            rank = self._distributed.rank
+            world_size = self._distributed.world_size
+            if ex_iterable.n_shards % world_size == 0:
+                if self._is_main_process():
+                    n_shards_per_node = ex_iterable.n_shards // world_size
+                    plural = "s" if n_shards_per_node > 1 else ""
+                    logger.warning(
+                        f"Assigning {n_shards_per_node} shard{plural} (or data source{plural}) of the dataset to each node."
+                    )
+                shards_indices = list(range(rank, ex_iterable.n_shards, world_size))
+                ex_iterable = ex_iterable.shard_data_sources(shards_indices)
             else:
-                logger.debug(
-                    f"dataloader worker#{worker_info.id}, ': Stopping... Number of dataset shards < num_workers ({self.n_shards}<{worker_info.num_workers})."
-                )
+                if self._is_main_process():
+                    logger.warning(
+                        f"Assigning 1 out of {world_size} examples of the dataset to each node. The others are skipped during the iteration."
+                    )
+                    logger.info(
+                        f"It is more efficient to distribute the dataset shards (or data sources) across nodes. "
+                        f"You can do that by using a dataset with a number of shards that is a multiple of world_size={world_size}. "
+                        f"The current dataset has {ex_iterable.n_shards} shards, which is not a multiple of {world_size}."
+                    )
+                ex_iterable = StepExamplesIterable(ex_iterable, step=world_size, offset=rank)
+
+        return ex_iterable
 
     def __iter__(self):
+        ex_iterable = self._prepare_ex_iterable_for_iteration()
+
         if "torch" in sys.modules:
             import torch.utils.data
 
             worker_info = torch.utils.data.get_worker_info()
             if isinstance(self, torch.utils.data.IterableDataset) and worker_info is not None:
                 # We're a torch.utils.data.IterableDataset in a PyTorch worker process
-                yield from self._iter_pytorch(worker_info)
+                yield from self._iter_pytorch(ex_iterable)
                 return
 
-        for key, example in self._iter():
+        for key, example in ex_iterable:
             if self.features:
                 # `IterableDataset` automatically fills missing columns with None.
                 # This is done with `_apply_feature_types_on_example`.
@@ -861,7 +932,7 @@ def iter(self, batch_size: int, drop_last_batch: bool = False):
             drop_last_batch (:obj:`bool`, default `False`): Whether a last batch smaller than the batch_size should be dropped
         """
-        iterator = self._iter()
+        iterator = iter(self._prepare_ex_iterable_for_iteration())
         for key, example in iterator:
             # If batched, first build the batch
             examples = [example] + [example for key, example in islice(iterator, batch_size - 1)]
@@ -880,7 +951,7 @@ def from_generator(
         generator: Callable,
         features: Optional[Features] = None,
         gen_kwargs: Optional[dict] = None,
-    ):
+    ) -> "IterableDataset":
         """Create an Iterable Dataset from a generator.
 
         Args:
@@ -953,6 +1024,7 @@ def with_format(
             split=self._split,
             format_type=type,
             shuffling=copy.deepcopy(self._shuffling),
+            distributed=copy.deepcopy(self._distributed),
             token_per_repo_id=self._token_per_repo_id,
         )
 
@@ -1064,6 +1136,7 @@ def map(
             split=self._split,
             format_type=self._format_type,
             shuffling=copy.deepcopy(self._shuffling),
+            distributed=copy.deepcopy(self._distributed),
             token_per_repo_id=self._token_per_repo_id,
         )
 
@@ -1136,6 +1209,7 @@ def filter(
             split=self._split,
             format_type=self._format_type,
             shuffling=copy.deepcopy(self._shuffling),
+            distributed=copy.deepcopy(self._distributed),
             token_per_repo_id=self._token_per_repo_id,
         )
 
@@ -1202,6 +1276,7 @@ def shuffle(
             split=self._split,
             format_type=self._format_type,
             shuffling=shuffling,
+            distributed=copy.deepcopy(self._distributed),
             token_per_repo_id=self._token_per_repo_id,
         )
 
@@ -1243,6 +1318,7 @@ def skip(self, n) -> "IterableDataset":
             split=self._split,
             format_type=self._format_type,
             shuffling=copy.deepcopy(self._shuffling),
+            distributed=copy.deepcopy(self._distributed),
             token_per_repo_id=self._token_per_repo_id,
         )
 
@@ -1274,6 +1350,7 @@ def take(self, n) -> "IterableDataset":
             split=self._split,
             format_type=self._format_type,
             shuffling=copy.deepcopy(self._shuffling),
+            distributed=copy.deepcopy(self._distributed),
             token_per_repo_id=self._token_per_repo_id,
         )
 
@@ -1464,6 +1541,7 @@ def cast_column(self, column: str, feature: FeatureType) -> "IterableDataset":
             split=self._split,
             format_type=self._format_type,
             shuffling=copy.deepcopy(self._shuffling),
+            distributed=copy.deepcopy(self._distributed),
             token_per_repo_id=self._token_per_repo_id,
         )
 
@@ -1514,6 +1592,19 @@ def cast(
             split=self._split,
             format_type=self._format_type,
             shuffling=copy.deepcopy(self._shuffling),
+            distributed=copy.deepcopy(self._distributed),
+            token_per_repo_id=self._token_per_repo_id,
+        )
+
+    def _step(self, step: int, offset: int) -> "IterableDataset":
+        ex_iterable = StepExamplesIterable(self._ex_iterable, step=step, offset=offset)
+        return IterableDataset(
+            ex_iterable=ex_iterable,
+            info=self._info.copy(),
+            split=self._split,
+            format_type=self._format_type,
+            shuffling=copy.deepcopy(self._shuffling),
+            distributed=copy.deepcopy(self._distributed),
             token_per_repo_id=self._token_per_repo_id,
         )
 
@@ -1532,6 +1623,7 @@ def _resolve_features(self):
             split=self._split,
             format_type=self._format_type,
             shuffling=copy.deepcopy(self._shuffling),
+            distributed=copy.deepcopy(self._distributed),
             token_per_repo_id=self._token_per_repo_id,
         )
 
@@ -1662,3 +1754,37 @@ def _interleave_iterable_datasets(
     }
     # Return new daset
     return IterableDataset(ex_iterable=ex_iterable, info=info, split=split, token_per_repo_id=token_per_repo_id)
+
+
+def _split_by_node_iterable_dataset(dataset: IterableDataset, rank: int, world_size: int) -> IterableDataset:
+    """
+    Split an iterable dataset for the node at rank `rank` in a pool of nodes of size `world_size`.
+
+    If the dataset has a number of shards that is a multiple of `world_size` (i.e. if `dataset.n_shards % world_size == 0`),
+    then the shards are evenly assigned across the nodes, which is the most efficient.
+    Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.
+
+    Args:
+        dataset ([`IterableDataset`]):
+            The iterable dataset to split by node.
+        rank (`int`):
+            Rank of the current node.
+        world_size (`int`):
+            Total number of nodes.
+
+    Returns:
+        [`IterableDataset`]: The iterable dataset to be used on the node at rank `rank`.
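+
+    Example (illustrative sketch; assumes `ds` is an `IterableDataset` with `ds.n_shards == 4`):
+
+    ```py
+    >>> ds_rank0 = _split_by_node_iterable_dataset(ds, rank=0, world_size=2)
+    >>> ds_rank0.n_shards  # 4 % 2 == 0, so each node is assigned 2 of the 4 shards
+    2
+    ```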
+ """ + if dataset._distributed: + world_size = world_size * dataset._distributed.world_size + rank = world_size * dataset._distributed.rank + rank + distributed = DistributedConfig(rank=rank, world_size=world_size) + return IterableDataset( + ex_iterable=dataset._ex_iterable, + info=dataset._info.copy(), + split=dataset._split, + format_type=dataset._format_type, + shuffling=copy.deepcopy(dataset._shuffling), + distributed=distributed, + token_per_repo_id=dataset._token_per_repo_id, + ) diff --git a/src/datasets/utils/sharding.py b/src/datasets/utils/sharding.py index 52cc0fe04e9..a785929638f 100644 --- a/src/datasets/utils/sharding.py +++ b/src/datasets/utils/sharding.py @@ -68,6 +68,15 @@ def _split_gen_kwargs(gen_kwargs: dict, max_num_jobs: int) -> List[dict]: ] +def _merge_gen_kwargs(gen_kwargs_list: List[dict]) -> dict: + return { + key: [value for gen_kwargs in gen_kwargs_list for value in gen_kwargs[key]] + if isinstance(gen_kwargs_list[0][key], list) + else gen_kwargs_list[0][key] + for key in gen_kwargs_list[0] + } + + def _shuffle_gen_kwargs(rng: np.random.Generator, gen_kwargs: dict) -> dict: """Return a shuffled copy of the input gen_kwargs""" # We must shuffle all the lists, and lists of the same size must have the same shuffling. diff --git a/tests/distributed_scripts/test_torch_distributed_launch.py b/tests/distributed_scripts/test_torch_distributed_launch.py new file mode 100644 index 00000000000..22064178709 --- /dev/null +++ b/tests/distributed_scripts/test_torch_distributed_launch.py @@ -0,0 +1,55 @@ +import os +from argparse import ArgumentParser +from typing import List + +import torch.utils.data + +from datasets import Dataset, IterableDataset +from datasets.distributed import split_dataset_by_node + + +NUM_SHARDS = 4 +NUM_ITEMS_PER_SHARD = 3 + + +class TestFailedError(RuntimeError): + pass + + +def gen(shards: List[str]): + for shard in shards: + for i in range(NUM_ITEMS_PER_SHARD): + yield {"i": i, "shard": shard} + + +def main(): + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + parser = ArgumentParser() + parser.add_argument("--streaming", type=bool) + parser.add_argument("--local_rank", type=int) + parser.add_argument("--num_workers", type=int, default=0) + args = parser.parse_args() + streaming = args.streaming + num_workers = args.num_workers + + gen_kwargs = {"shards": [f"shard_{shard_idx}" for shard_idx in range(NUM_SHARDS)]} + ds = IterableDataset.from_generator(gen, gen_kwargs=gen_kwargs) + if not streaming: + ds = Dataset.from_list(list(ds)) + + ds = split_dataset_by_node(ds, rank=rank, world_size=world_size) + dataloader = torch.utils.data.DataLoader(ds, num_workers=num_workers) + + full_size = NUM_SHARDS * NUM_ITEMS_PER_SHARD + expected_local_size = full_size // world_size + expected_local_size += int(rank < (full_size % world_size)) + + local_size = sum(1 for _ in dataloader) + if local_size != expected_local_size: + raise TestFailedError(f"local_size {local_size} != expected_local_size {expected_local_size}") + + +if __name__ == "__main__": + main() diff --git a/tests/test_distributed.py b/tests/test_distributed.py new file mode 100644 index 00000000000..2aa6f2809dc --- /dev/null +++ b/tests/test_distributed.py @@ -0,0 +1,104 @@ +import os +import sys +from pathlib import Path + +import pytest + +from datasets import Dataset, IterableDataset +from datasets.distributed import split_dataset_by_node + +from .utils import execute_subprocess_async, get_torch_dist_unique_port, require_torch + + +def 
+def test_split_dataset_by_node_map_style():
+    full_ds = Dataset.from_dict({"i": range(17)})
+    full_size = len(full_ds)
+    world_size = 3
+    datasets_per_rank = [
+        split_dataset_by_node(full_ds, rank=rank, world_size=world_size) for rank in range(world_size)
+    ]
+    assert sum(len(ds) for ds in datasets_per_rank) == full_size
+    assert len(set(tuple(x.values()) for ds in datasets_per_rank for x in ds)) == full_size
+
+
+def test_split_dataset_by_node_iterable():
+    def gen():
+        return ({"i": i} for i in range(17))
+
+    world_size = 3
+    full_ds = IterableDataset.from_generator(gen)
+    full_size = len(list(full_ds))
+    datasets_per_rank = [
+        split_dataset_by_node(full_ds, rank=rank, world_size=world_size) for rank in range(world_size)
+    ]
+    assert sum(len(list(ds)) for ds in datasets_per_rank) == full_size
+    assert len(set(tuple(x.values()) for ds in datasets_per_rank for x in ds)) == full_size
+
+
+@pytest.mark.parametrize("shards_per_node", [1, 2, 3])
+def test_split_dataset_by_node_iterable_sharded(shards_per_node):
+    def gen(shards):
+        for shard in shards:
+            yield from ({"i": i, "shard": shard} for i in range(17))
+
+    world_size = 3
+    num_shards = shards_per_node * world_size
+    gen_kwargs = {"shards": [f"shard_{shard_idx}.txt" for shard_idx in range(num_shards)]}
+    full_ds = IterableDataset.from_generator(gen, gen_kwargs=gen_kwargs)
+    full_size = len(list(full_ds))
+    assert full_ds.n_shards == world_size * shards_per_node
+    datasets_per_rank = [
+        split_dataset_by_node(full_ds, rank=rank, world_size=world_size) for rank in range(world_size)
+    ]
+    assert [ds.n_shards for ds in datasets_per_rank] == [shards_per_node] * world_size
+    assert sum(len(list(ds)) for ds in datasets_per_rank) == full_size
+    assert len(set(tuple(x.values()) for ds in datasets_per_rank for x in ds)) == full_size
+
+
+@pytest.mark.parametrize("streaming", [False, True])
+@require_torch
+@pytest.mark.skipif(os.name == "nt", reason="execute_subprocess_async doesn't support windows")
+@pytest.mark.integration
+def test_torch_distributed_launch(streaming):
+    nproc_per_node = 2
+    master_port = get_torch_dist_unique_port()
+    test_script = Path(__file__).resolve().parent / "distributed_scripts" / "test_torch_distributed_launch.py"
+    distributed_args = f"""
+        -m torch.distributed.launch
+        --nproc_per_node={nproc_per_node}
+        --master_port={master_port}
+        {test_script}
+    """.split()
+    args = f"""
+        --streaming={streaming}
+    """.split()
+    cmd = [sys.executable] + distributed_args + args
+    execute_subprocess_async(cmd, env=os.environ.copy())
+
+
+@pytest.mark.parametrize(
+    "nproc_per_node, num_workers",
+    [
+        (2, 2),  # each node has 2 shards and each worker has 1 shard
+        (3, 2),  # each node uses all the shards but skips examples, and each worker has 2 shards
+    ],
+)
+@require_torch
+@pytest.mark.skipif(os.name == "nt", reason="execute_subprocess_async doesn't support windows")
+@pytest.mark.integration
+def test_torch_distributed_launch_streaming_with_num_workers(nproc_per_node, num_workers):
+    streaming = True
+    master_port = get_torch_dist_unique_port()
+    test_script = Path(__file__).resolve().parent / "distributed_scripts" / "test_torch_distributed_launch.py"
+    distributed_args = f"""
+        -m torch.distributed.launch
+        --nproc_per_node={nproc_per_node}
+        --master_port={master_port}
+        {test_script}
+    """.split()
+    args = f"""
+        --streaming={streaming}
+        --num_workers={num_workers}
+    """.split()
+    cmd = [sys.executable] + distributed_args + args
+    execute_subprocess_async(cmd, env=os.environ.copy())
diff --git a/tests/utils.py b/tests/utils.py
index dc0fc633371..7168c7f7619 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,4 +1,6 @@
+import asyncio
 import os
+import re
 import sys
 import tempfile
 import unittest
@@ -417,3 +419,110 @@ def _wrapper(func, *args, **kwargs):
                 raise err
 
     return decorator.decorator(_wrapper, func)
+
+
+# --- distributed testing functions --- #
+
+# copied from transformers
+# originally adapted from https://stackoverflow.com/a/59041913/9201239
+
+
+class _RunOutput:
+    def __init__(self, returncode, stdout, stderr):
+        self.returncode = returncode
+        self.stdout = stdout
+        self.stderr = stderr
+
+
+async def _read_stream(stream, callback):
+    while True:
+        line = await stream.readline()
+        if line:
+            callback(line)
+        else:
+            break
+
+
+async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
+    if echo:
+        print("\nRunning: ", " ".join(cmd))
+
+    p = await asyncio.create_subprocess_exec(
+        cmd[0],
+        *cmd[1:],
+        stdin=stdin,
+        stdout=asyncio.subprocess.PIPE,
+        stderr=asyncio.subprocess.PIPE,
+        env=env,
+    )
+
+    # note: there is a warning for a possible deadlock when using `wait` with huge amounts of data in the pipe
+    # https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait
+    #
+    # If it starts hanging, will need to switch to the following code. The problem is that no data
+    # will be seen until it's done and if it hangs for example there will be no debug info.
+    # out, err = await p.communicate()
+    # return _RunOutput(p.returncode, out, err)
+
+    out = []
+    err = []
+
+    def tee(line, sink, pipe, label=""):
+        line = line.decode("utf-8").rstrip()
+        sink.append(line)
+        if not quiet:
+            print(label, line, file=pipe)
+
+    # XXX: the timeout doesn't seem to make any difference here
+    await asyncio.wait(
+        [
+            _read_stream(p.stdout, lambda l: tee(l, out, sys.stdout, label="stdout:")),
+            _read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:")),
+        ],
+        timeout=timeout,
+    )
+    return _RunOutput(await p.wait(), out, err)
+
+
+def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput:
+    loop = asyncio.get_event_loop()
+    result = loop.run_until_complete(
+        _stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
+    )
+
+    cmd_str = " ".join(cmd)
+    if result.returncode > 0:
+        stderr = "\n".join(result.stderr)
+        raise RuntimeError(
+            f"'{cmd_str}' failed with returncode {result.returncode}\n\n"
+            f"The combined stderr from workers follows:\n{stderr}"
+        )
+
+    # check that the subprocess actually did run and produced some output, should the test rely on
+    # the remote side to do the testing
+    if not result.stdout and not result.stderr:
+        raise RuntimeError(f"'{cmd_str}' produced no output.")
+
+    return result
+
+
+def pytest_xdist_worker_id():
+    """
+    Returns an int value of worker's numerical id under `pytest-xdist`'s concurrent workers `pytest -n N` regime, or 0
+    if `-n 1` or `pytest-xdist` isn't being used.
+    """
+    worker = os.environ.get("PYTEST_XDIST_WORKER", "gw0")
+    worker = re.sub(r"^gw", "", worker, 0, re.M)
+    return int(worker)
+
+
+def get_torch_dist_unique_port():
+    """
+    Returns a port number that can be fed to `torch.distributed.launch`'s `--master_port` argument.
+
+    Under `pytest-xdist` it adds a delta number based on a worker id so that concurrent tests don't try to use the same
+    port at once.
+    """
+    port = 29500
+    uniq_delta = pytest_xdist_worker_id()
+    return port + uniq_delta