Distributed support (huggingface#5369)
* add split_dataset_by_node

* docs

* tests

* more docs

* update IterableDataset._head()

* don't run integration test on windows

* style

* skip on windows

* add note about contiguous data splitting
lhoestq authored Jan 16, 2023
1 parent 5b793dd commit 9991c74
Showing 10 changed files with 556 additions and 66 deletions.
2 changes: 2 additions & 0 deletions docs/source/package_reference/main_classes.mdx
@@ -100,6 +100,8 @@ The base class [`Dataset`] implements a Dataset backed by an Apache Arrow table.

[[autodoc]] datasets.interleave_datasets

[[autodoc]] datasets.distributed.split_dataset_by_node

[[autodoc]] datasets.enable_caching

[[autodoc]] datasets.disable_caching
30 changes: 28 additions & 2 deletions docs/source/use_with_pytorch.mdx
@@ -201,7 +201,7 @@ You must use a `BatchSampler` if you want the transform to be given full batches
### Stream data

Loading a dataset in streaming mode allows you to iterate over the dataset without downloading it to disk.
-An iterable dataset from `datasets` inherits from `torch.utils.data.IterableDataset` so you can pass it to a `DataLoader`:
+An iterable dataset from `datasets` inherits from `torch.utils.data.IterableDataset` so you can pass it to a `torch.utils.data.DataLoader`:

```py
>>> import numpy as np
@@ -222,4 +222,30 @@ If the dataset is split in several shards (i.e. if the dataset consists of multi
>>> dataloader = DataLoader(ds, batch_size=32, num_workers=4)
```

-In this case each worker will be given a subset of the list of shards to stream from.
+In this case each worker is given a subset of the list of shards to stream from.

### Distributed

To split your dataset across your training nodes, you can use [`datasets.distributed.split_dataset_by_node`]:

```python
import os
from datasets.distributed import split_dataset_by_node

ds = split_dataset_by_node(ds, rank=int(os.environ["RANK"]), world_size=int(os.environ["WORLD_SIZE"]))
```

This works for both map-style datasets and iterable datasets.
The dataset is split for the node at rank `rank` in a pool of nodes of size `world_size`.

For map-style datasets:

Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of the dataset.
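
For example, here is a minimal sketch of this behavior (assuming a toy in-memory dataset built with `Dataset.from_dict`; the data and sizes are made up for illustration):

```python
from datasets import Dataset
from datasets.distributed import split_dataset_by_node

ds = Dataset.from_dict({"x": list(range(8))})  # toy dataset with 8 examples
ds_rank0 = split_dataset_by_node(ds, rank=0, world_size=4)
print(ds_rank0["x"])  # expected: [0, 1], the first contiguous chunk
```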

For iterable datasets:

If the dataset has a number of shards that is a multiple of `world_size` (i.e. if `dataset.n_shards % world_size == 0`),
then the shards are evenly assigned across the nodes, which is the most efficient.
Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.
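
As a minimal sketch of the fallback case (assuming a toy single-shard `IterableDataset` built with `IterableDataset.from_generator`; names and sizes are made up for illustration):

```python
from datasets import IterableDataset
from datasets.distributed import split_dataset_by_node

def gen():
    for i in range(8):
        yield {"x": i}

ds = IterableDataset.from_generator(gen)  # a single shard, so n_shards % world_size != 0 for world_size=2
ds_rank0 = split_dataset_by_node(ds, rank=0, world_size=2)
print([example["x"] for example in ds_rank0])  # expected: [0, 2, 4, 6], i.e. 1 example out of every 2
```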

This can also be combined with a `torch.utils.data.DataLoader` if you want each node to use multiple workers to load the data.
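
For instance, a sketch of this combination (assuming `ds` is the streaming `IterableDataset` from the section above, and that the `RANK` and `WORLD_SIZE` environment variables are set by your launcher):

```python
import os
from datasets.distributed import split_dataset_by_node
from torch.utils.data import DataLoader

# Split the shards (or examples) across the nodes first...
ds = split_dataset_by_node(ds, rank=int(os.environ["RANK"]), world_size=int(os.environ["WORLD_SIZE"]))
# ...then let each node's DataLoader spawn several worker processes to stream in parallel.
dataloader = DataLoader(ds, batch_size=32, num_workers=4)
```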
20 changes: 20 additions & 0 deletions src/datasets/arrow_dataset.py
@@ -5628,6 +5628,26 @@ def iter_random_indices():
return concatenated_datasets.select(indices, **kwargs)


def _split_by_node_map_style_dataset(dataset: Dataset, rank: int, world_size: int) -> Dataset:
    """
    Split a dataset for the node at rank `rank` in a pool of nodes of size `world_size`.
    Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of the dataset.
    To maximize data loading throughput, chunks are made of contiguous data on disk if possible.

    Args:
        dataset ([`Dataset`]):
            The dataset to split by node.
        rank (`int`):
            Rank of the current node.
        world_size (`int`):
            Total number of nodes.

    Returns:
        [`Dataset`]: The dataset to be used on the node at rank `rank`.
    """
    return dataset.shard(num_shards=world_size, index=rank, contiguous=True)
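
As a quick sketch of the contiguous sharding this helper relies on (toy in-memory data, chosen only for illustration):

```python
from datasets import Dataset

ds = Dataset.from_dict({"x": list(range(10))})
# contiguous=True gives each node one contiguous slice of the underlying table
print(ds.shard(num_shards=2, index=0, contiguous=True)["x"])  # expected: [0, 1, 2, 3, 4]
print(ds.shard(num_shards=2, index=1, contiguous=True)["x"])  # expected: [5, 6, 7, 8, 9]
```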


# This is outside Dataset.filter as it needs to be picklable for multiprocessing


6 changes: 3 additions & 3 deletions src/datasets/combine.py
@@ -10,7 +10,7 @@
logger = logging.get_logger(__name__)


DatasetType = TypeVar("DatasetType", "Dataset", "IterableDataset")
DatasetType = TypeVar("DatasetType", Dataset, IterableDataset)


def interleave_datasets(
@@ -137,11 +137,11 @@ def interleave_datasets(


def concatenate_datasets(
-    dsets: List[Dataset],
+    dsets: List[DatasetType],
    info: Optional[DatasetInfo] = None,
    split: Optional[NamedSplit] = None,
    axis: int = 0,
-):
+) -> DatasetType:
    """
    Converts a list of [`Dataset`] with the same schema into a single [`Dataset`].
39 changes: 39 additions & 0 deletions src/datasets/distributed.py
@@ -0,0 +1,39 @@
from typing import TypeVar

from .arrow_dataset import Dataset, _split_by_node_map_style_dataset
from .iterable_dataset import IterableDataset, _split_by_node_iterable_dataset


DatasetType = TypeVar("DatasetType", Dataset, IterableDataset)


def split_dataset_by_node(dataset: DatasetType, rank: int, world_size: int) -> DatasetType:
    """
    Split a dataset for the node at rank `rank` in a pool of nodes of size `world_size`.

    For map-style datasets:

    Each node is assigned a chunk of data, e.g. rank 0 is given the first chunk of the dataset.
    To maximize data loading throughput, chunks are made of contiguous data on disk if possible.

    For iterable datasets:

    If the dataset has a number of shards that is a multiple of `world_size` (i.e. if `dataset.n_shards % world_size == 0`),
    then the shards are evenly assigned across the nodes, which is the most efficient.
    Otherwise, each node keeps 1 example out of `world_size`, skipping the other examples.

    Args:
        dataset ([`Dataset`] or [`IterableDataset`]):
            The dataset to split by node.
        rank (`int`):
            Rank of the current node.
        world_size (`int`):
            Total number of nodes.

    Returns:
        [`Dataset`] or [`IterableDataset`]: The dataset to be used on the node at rank `rank`.
    """
    if isinstance(dataset, Dataset):
        return _split_by_node_map_style_dataset(dataset, rank=rank, world_size=world_size)
    else:
        return _split_by_node_iterable_dataset(dataset, rank=rank, world_size=world_size)