Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

IterableDataset Arrow formatting #5821

Merged
merged 11 commits into from
May 31, 2023
23 changes: 10 additions & 13 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
update_fingerprint,
validate_fingerprint,
)
from .formatting import PythonFormatter, format_table, get_format_type_from_alias, get_formatter, query_table
from .formatting import format_table, get_format_type_from_alias, get_formatter, query_table
from .formatting.formatting import LazyDict, _is_range_contiguous
from .info import DatasetInfo, DatasetInfosDict
from .naming import _split_re
Expand Down Expand Up @@ -4988,16 +4988,10 @@ def extra_nbytes_visitor(array, feature):
return dataset_nbytes

@staticmethod
def _generate_examples_from_shards(shards: List["Dataset"]):
python_formatter = PythonFormatter()
for shards_idx, shard in enumerate(shards):
example_idx = 0
for pa_table in shard.with_format("arrow").iter(config.ARROW_READER_BATCH_SIZE_IN_DATASET_ITER):
batch = python_formatter.format_batch(pa_table)
for i in range(len(pa_table)):
example = {col: array[i] for col, array in batch.items()}
yield f"{shards_idx}_{example_idx}", example
example_idx += 1
def _generate_tables_from_shards(shards: List["Dataset"], batch_size: int):
for shard_idx, shard in enumerate(shards):
for pa_table in shard.with_format("arrow").iter(batch_size):
yield shard_idx, pa_table

def to_iterable_dataset(self, num_shards: Optional[int] = 1) -> "IterableDataset":
"""Get an [`datasets.IterableDataset`] from a map-style [`datasets.Dataset`].
Expand Down Expand Up @@ -5090,7 +5084,7 @@ def to_iterable_dataset(self, num_shards: Optional[int] = 1) -> "IterableDataset
```
Feel free to also use [`IterableDataset.set_epoch`] when using a PyTorch DataLoader or in distributed setups.
"""
from .iterable_dataset import ExamplesIterable, IterableDataset
from .iterable_dataset import ArrowExamplesIterable, IterableDataset

if self._format_type is not None:
raise NotImplementedError(
Expand All @@ -5112,7 +5106,10 @@ def to_iterable_dataset(self, num_shards: Optional[int] = 1) -> "IterableDataset
self.shard(num_shards=num_shards, index=shard_idx, contiguous=True) for shard_idx in range(num_shards)
]
)
ex_iterable = ExamplesIterable(Dataset._generate_examples_from_shards, kwargs={"shards": shards})
ex_iterable = ArrowExamplesIterable(
Dataset._generate_tables_from_shards,
kwargs={"shards": shards, "batch_size": config.DEFAULT_MAX_BATCH_SIZE},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering if we should allow users to pass a custom `batch_size`.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea I'm not sure - we can wait for some feedback on this and improve later imo

)
return IterableDataset(ex_iterable, info=DatasetInfo(features=self.features))

def _push_parquet_shards_to_hub(
Expand Down
6 changes: 2 additions & 4 deletions src/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
)
from .fingerprint import Hasher
from .info import DatasetInfo, DatasetInfosDict, PostProcessedInfo
from .iterable_dataset import ExamplesIterable, IterableDataset, _generate_examples_from_tables_wrapper
from .iterable_dataset import ArrowExamplesIterable, ExamplesIterable, IterableDataset
from .keyhash import DuplicatedKeysError
from .naming import INVALID_WINDOWS_CHARACTERS_IN_PATH, camelcase_to_snakecase
from .splits import Split, SplitDict, SplitGenerator, SplitInfo
Expand Down Expand Up @@ -1897,9 +1897,7 @@ def _prepare_split_single(
yield job_id, True, (total_num_examples, total_num_bytes, writer._features, num_shards, shard_lengths)

def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable:
    """Return the examples iterable that streams this split's data.

    Wraps this builder's ``self._generate_tables`` in an
    ``ArrowExamplesIterable`` so consumers iterate over pyarrow tables
    directly, passing along the generator kwargs recorded on
    ``split_generator``.
    """
    return ArrowExamplesIterable(self._generate_tables, kwargs=split_generator.gen_kwargs)


class MissingBeamOptions(ValueError):
Expand Down
Loading