diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 5d014962798..c2967649161 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -88,7 +88,7 @@ update_fingerprint, validate_fingerprint, ) -from .formatting import PythonFormatter, format_table, get_format_type_from_alias, get_formatter, query_table +from .formatting import format_table, get_format_type_from_alias, get_formatter, query_table from .formatting.formatting import LazyDict, _is_range_contiguous from .info import DatasetInfo, DatasetInfosDict from .naming import _split_re @@ -4988,16 +4988,10 @@ def extra_nbytes_visitor(array, feature): return dataset_nbytes @staticmethod - def _generate_examples_from_shards(shards: List["Dataset"]): - python_formatter = PythonFormatter() - for shards_idx, shard in enumerate(shards): - example_idx = 0 - for pa_table in shard.with_format("arrow").iter(config.ARROW_READER_BATCH_SIZE_IN_DATASET_ITER): - batch = python_formatter.format_batch(pa_table) - for i in range(len(pa_table)): - example = {col: array[i] for col, array in batch.items()} - yield f"{shards_idx}_{example_idx}", example - example_idx += 1 + def _generate_tables_from_shards(shards: List["Dataset"], batch_size: int): + for shard_idx, shard in enumerate(shards): + for pa_table in shard.with_format("arrow").iter(batch_size): + yield shard_idx, pa_table def to_iterable_dataset(self, num_shards: Optional[int] = 1) -> "IterableDataset": """Get an [`datasets.IterableDataset`] from a map-style [`datasets.Dataset`]. @@ -5090,7 +5084,7 @@ def to_iterable_dataset(self, num_shards: Optional[int] = 1) -> "IterableDataset ``` Feel free to also use [`IterableDataset.set_epoch`] when using a PyTorch DataLoader or in distributed setups. """ - from .iterable_dataset import ExamplesIterable, IterableDataset + from .iterable_dataset import ArrowExamplesIterable, IterableDataset if self._format_type is not None: raise NotImplementedError( @@ -5112,7 +5106,10 @@ def to_iterable_dataset(self, num_shards: Optional[int] = 1) -> "IterableDataset self.shard(num_shards=num_shards, index=shard_idx, contiguous=True) for shard_idx in range(num_shards) ] ) - ex_iterable = ExamplesIterable(Dataset._generate_examples_from_shards, kwargs={"shards": shards}) + ex_iterable = ArrowExamplesIterable( + Dataset._generate_tables_from_shards, + kwargs={"shards": shards, "batch_size": config.DEFAULT_MAX_BATCH_SIZE}, + ) return IterableDataset(ex_iterable, info=DatasetInfo(features=self.features)) def _push_parquet_shards_to_hub( diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 13190244ed1..11027fbb0fa 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -59,7 +59,7 @@ ) from .fingerprint import Hasher from .info import DatasetInfo, DatasetInfosDict, PostProcessedInfo -from .iterable_dataset import ExamplesIterable, IterableDataset, _generate_examples_from_tables_wrapper +from .iterable_dataset import ArrowExamplesIterable, ExamplesIterable, IterableDataset from .keyhash import DuplicatedKeysError from .naming import INVALID_WINDOWS_CHARACTERS_IN_PATH, camelcase_to_snakecase from .splits import Split, SplitDict, SplitGenerator, SplitInfo @@ -1897,9 +1897,7 @@ def _prepare_split_single( yield job_id, True, (total_num_examples, total_num_bytes, writer._features, num_shards, shard_lengths) def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable: - return ExamplesIterable( - _generate_examples_from_tables_wrapper(self._generate_tables), 
kwargs=split_generator.gen_kwargs - ) + return ArrowExamplesIterable(self._generate_tables, kwargs=split_generator.gen_kwargs) class MissingBeamOptions(ValueError): diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 1781767a2ea..a408fe4f4b8 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -5,7 +5,7 @@ from copy import deepcopy from dataclasses import dataclass from itertools import cycle, islice -from typing import Any, Callable, Dict, Iterator, List, Optional, Union +from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union import numpy as np import pyarrow as pa @@ -18,7 +18,7 @@ from .formatting import PythonFormatter, get_format_type_from_alias from .info import DatasetInfo from .splits import NamedSplit -from .table import table_cast +from .table import cast_table_to_features, table_cast from .utils.logging import get_logger from .utils.py_utils import Literal from .utils.sharding import _merge_gen_kwargs, _number_of_shards_in_gen_kwargs, _shuffle_gen_kwargs, _split_gen_kwargs @@ -26,6 +26,8 @@ logger = get_logger(__name__) +Key = Union[int, str] + def _infer_features_from_batch(batch: Dict[str, list], try_features: Optional[Features] = None) -> Features: pa_table = pa.Table.from_pydict(batch) @@ -82,10 +84,94 @@ def hasnext(self): return self._hasnext +def _convert_to_arrow( + iterable: Iterable[Tuple[Key, dict]], + batch_size: Optional[int], + drop_last_batch: bool = False, +) -> Iterator[Tuple[Key, pa.Table]]: + """Convert and group examples in Arrow tables of size `batch_size`. + + Args: + iterable (`Iterable[Tuple[Key, dict]]`): + An examples iterable containing tuples (example_key, example) of type (int/str, dict) + batch_size (`Optional[int]`): + Size of each sub-table to yield. If None or <= 0, yields the full table. + drop_last_batch (`bool`, defaults to `False`): + Drop the last batch if it is smaller than `batch_size`. + """ + if batch_size is None or batch_size <= 0: + yield "all", pa.Table.from_pylist([example for _, example in iterable]) + return + iterator = iter(iterable) + for key, example in iterator: + iterator_batch = islice(iterator, batch_size - 1) + key_examples_list = [(key, example)] + [(key, example) for key, example in iterator_batch] + if len(key_examples_list) < batch_size and drop_last_batch: + return + keys, examples = zip(*key_examples_list) + new_key = "_".join(str(key) for key in keys) + yield new_key, pa.Table.from_pylist(examples) + + +def _batch_arrow_tables( + iterable: Iterable[Tuple[Key, pa.Table]], + batch_size: Optional[int], + drop_last_batch: bool = False, +) -> Iterator[Tuple[Key, pa.Table]]: + """Iterate over sub-tables of size `batch_size`. + + Args: + iterable (`Iterable[Tuple[Key, pa.Table]]`): + A tables iterable containing tuples (table_key, table) of type (int/str, pa.Table) + batch_size (`Optional[int]`): + Size of each sub-table to yield. If None or <= 0, yields the full table. + drop_last_batch (`bool`, defaults to `False`): + Drop the last batch if it is smaller than `batch_size`.
+ """ + if batch_size is None or batch_size <= 0: + yield "all", pa.concat_tables([pa_table for _, pa_table in iterable]) + return + keys_buffer = [] + chunks_buffer = [] + chunks_buffer_size = 0 + for key, pa_table in iterable: + for chunk in pa_table.to_reader(max_chunksize=batch_size): + if len(chunk) == 0: + continue + elif chunks_buffer_size + len(chunk) < batch_size: + keys_buffer.append(key) + chunks_buffer.append(chunk) + chunks_buffer_size += len(chunk) + continue + elif chunks_buffer_size + len(chunk) == batch_size: + keys_buffer.append(key) + chunks_buffer.append(chunk) + new_key = "_".join(str(_key) for _key in keys_buffer) + yield new_key, pa.Table.from_batches(chunks_buffer) + keys_buffer = [] + chunks_buffer = [] + chunks_buffer_size = 0 + else: + cropped_chunk_length = batch_size - chunks_buffer_size + keys_buffer.append(f"{key}[:{cropped_chunk_length}]") + chunks_buffer.append(chunk.slice(0, cropped_chunk_length)) + new_key = "_".join(str(_key) for _key in keys_buffer) + yield new_key, pa.Table.from_batches(chunks_buffer) + keys_buffer = [f"{key}[{cropped_chunk_length}:]"] + chunks_buffer = [chunk.slice(cropped_chunk_length, len(chunk) - cropped_chunk_length)] + chunks_buffer_size = len(chunk) - cropped_chunk_length + if not drop_last_batch and chunks_buffer: + new_key = "_".join(str(_key) for _key in keys_buffer) + yield new_key, pa.Table.from_batches(chunks_buffer) + + class _BaseExamplesIterable: """Base class for the examples iterable used by an IterableDataset""" - def __iter__(self): + def __init__(self) -> None: + self.iter_arrow: Optional[Callable[[], Iterator[Tuple[Key, pa.Table]]]] = None + + def __iter__(self) -> Iterator[Tuple[Key, dict]]: """An examples iterable should yield tuples (example_key, example) of type (int/str, dict)""" raise NotImplementedError(f"{type(self)} doesn't implement __iter__ yet") @@ -109,7 +195,8 @@ def n_shards(self) -> int: class ExamplesIterable(_BaseExamplesIterable): - def __init__(self, generate_examples_fn: Callable, kwargs: dict): + def __init__(self, generate_examples_fn: Callable[..., Tuple[Key, dict]], kwargs: dict): + super().__init__() self.generate_examples_fn = generate_examples_fn self.kwargs = kwargs @@ -132,7 +219,9 @@ def n_shards(self) -> int: class ShuffledDataSourcesExamplesIterable(ExamplesIterable): - def __init__(self, generate_examples_fn: Callable, kwargs: dict, generator: np.random.Generator): + def __init__( + self, generate_examples_fn: Callable[..., Tuple[Key, dict]], kwargs: dict, generator: np.random.Generator + ): super().__init__(generate_examples_fn, kwargs) self.generator = deepcopy(generator) @@ -151,15 +240,91 @@ def shard_data_sources(self, worker_id: int, num_workers: int) -> "ExamplesItera ) +class ArrowExamplesIterable(_BaseExamplesIterable): + def __init__(self, generate_tables_fn: Callable[..., Tuple[Key, pa.Table]], kwargs: dict): + super().__init__() + self.generate_tables_fn = generate_tables_fn + self.kwargs = kwargs + self.iter_arrow = self._iter_arrow + + def __iter__(self): + formatter = PythonFormatter() + for key, pa_table in self.generate_tables_fn(**self.kwargs): + for pa_subtable in pa_table.to_reader(max_chunksize=config.ARROW_READER_BATCH_SIZE_IN_DATASET_ITER): + formatted_batch = formatter.format_batch(pa_subtable) + for example in _batch_to_examples(formatted_batch): + yield key, example + + def _iter_arrow(self): + yield from self.generate_tables_fn(**self.kwargs) + + def shuffle_data_sources(self, generator: np.random.Generator) -> "ExamplesIterable": + return 
ShuffledDataSourcesArrowExamplesIterable(self.generate_tables_fn, self.kwargs, generator) + + def shard_data_sources(self, shard_indices: List[int]) -> "ExamplesIterable": + """Keep only the requested shard.""" + gen_kwargs_list = _split_gen_kwargs(self.kwargs, max_num_jobs=self.n_shards) + requested_gen_kwargs = _merge_gen_kwargs([gen_kwargs_list[i] for i in shard_indices]) + return ArrowExamplesIterable(self.generate_tables_fn, requested_gen_kwargs) + + @property + def n_shards(self) -> int: + return _number_of_shards_in_gen_kwargs(self.kwargs) + + +class ShuffledDataSourcesArrowExamplesIterable(ArrowExamplesIterable): + def __init__( + self, + generate_tables_fn: Callable[..., Tuple[Key, pa.Table]], + kwargs: dict, + generator: np.random.Generator, + ): + super().__init__(generate_tables_fn, kwargs) + self.generator = deepcopy(generator) + + def __iter__(self): + """Shuffle the kwargs order to shuffle shards""" + rng = deepcopy(self.generator) + kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) + formatter = PythonFormatter() + for key, pa_table in self.generate_tables_fn(**kwargs_with_shuffled_shards): + for pa_subtable in pa_table.to_reader(max_chunksize=config.ARROW_READER_BATCH_SIZE_IN_DATASET_ITER): + formatted_batch = formatter.format_batch(pa_subtable) + for example in _batch_to_examples(formatted_batch): + yield key, example + + def _iter_arrow(self): + rng = deepcopy(self.generator) + kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) + yield from self.generate_tables_fn(**kwargs_with_shuffled_shards) + + def shard_data_sources(self, shard_indices: List[int]) -> "ExamplesIterable": + """Keep only the requested shard.""" + rng = deepcopy(self.generator) + kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) + return ArrowExamplesIterable(self.generate_tables_fn, kwargs_with_shuffled_shards).shard_data_sources( + shard_indices + ) + + class SelectColumnsIterable(_BaseExamplesIterable): def __init__(self, ex_iterable: _BaseExamplesIterable, column_names: List[str]): + super().__init__() self.ex_iterable = ex_iterable self.column_names = column_names + if self.ex_iterable.iter_arrow: + self.iter_arrow = self._iter_arrow def __iter__(self): for idx, row in self.ex_iterable: yield idx, {c: row[c] for c in self.column_names} + def _iter_arrow(self) -> Iterator[Tuple[Key, pa.Table]]: + for idx, pa_table in self.ex_iterable.iter_arrow(): + yield idx, pa_table.select(self.column_names) + def shuffle_data_sources(self, generator: np.random.Generator) -> "SelectColumnsIterable": return SelectColumnsIterable(self.ex_iterable.shuffle_data_sources(generator), self.column_names) @@ -173,9 +338,11 @@ def n_shards(self) -> int: class StepExamplesIterable(_BaseExamplesIterable): def __init__(self, ex_iterable: _BaseExamplesIterable, step: int, offset: int): + super().__init__() self.ex_iterable = ex_iterable self.step = step self.offset = offset + # TODO(QL): implement iter_arrow def __iter__(self): ex_iterator = iter(self.ex_iterable) @@ -207,12 +374,14 @@ def __init__( ex_iterables: List[_BaseExamplesIterable], stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted", ): + super().__init__() self.ex_iterables = ex_iterables self.stopping_strategy = stopping_strategy # if undersampling ("first_exhausted"), we stop as soon as one dataset is exhausted # if oversampling ("all_exhausted"), we stop as soons as every dataset is exhausted, i.e as soon as every samples of every
dataset has been visited at least once self.bool_strategy_func = np.all if (stopping_strategy == "all_exhausted") else np.any + # TODO(QL): implement iter_arrow def _get_indices_iterator(self): # this is an infinite iterator to keep track of which iterator we want to pick examples from @@ -278,12 +447,19 @@ class VerticallyConcatenatedMultiSourcesExamplesIterable(_BaseExamplesIterable): """ def __init__(self, ex_iterables: List[_BaseExamplesIterable]): + super().__init__() self.ex_iterables = ex_iterables + if all(ex_iterable.iter_arrow is not None for ex_iterable in ex_iterables): + self.iter_arrow = self._iter_arrow def __iter__(self): for ex_iterable in self.ex_iterables: yield from ex_iterable + def _iter_arrow(self): + for ex_iterable in self.ex_iterables: + yield from ex_iterable.iter_arrow() + def shuffle_data_sources( self, generator: np.random.Generator ) -> "VerticallyConcatenatedMultiSourcesExamplesIterable": @@ -334,7 +510,9 @@ class HorizontallyConcatenatedMultiSourcesExamplesIterable(_BaseExamplesIterable """ def __init__(self, ex_iterables: List[_BaseExamplesIterable]): + super().__init__() self.ex_iterables = ex_iterables + # TODO(QL): implement iter_arrow def __iter__(self): ex_iterators = [iter(ex_iterable) for ex_iterable in self.ex_iterables] @@ -381,7 +559,7 @@ def shard_data_sources( class RandomlyCyclingMultiSourcesExamplesIterable(CyclingMultiSourcesExamplesIterable): def __init__( self, - ex_iterables, + ex_iterables: List[_BaseExamplesIterable], generator: np.random.Generator, probabilities: Optional[List[float]] = None, stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted", @@ -389,6 +567,7 @@ def __init__( super().__init__(ex_iterables, stopping_strategy) self.generator = deepcopy(generator) self.probabilities = probabilities + # TODO(QL): implement iter_arrow @staticmethod def _iter_random_indices( @@ -442,7 +621,9 @@ def __init__( drop_last_batch: bool = False, remove_columns: Optional[List[str]] = None, fn_kwargs: Optional[dict] = None, + format_type: Optional[str] = None, ): + super().__init__() self.ex_iterable = ex_iterable self.function = function self.batched = batched @@ -452,8 +633,17 @@ def __init__( self.with_indices = with_indices self.input_columns = input_columns self.fn_kwargs = fn_kwargs or {} + self.format_type = get_format_type_from_alias(format_type) + if format_type == "arrow": + self.iter_arrow = self._iter_arrow def __iter__(self): + if self.format_type == "arrow": + yield from ArrowExamplesIterable(self._iter_arrow, {}) + else: + yield from self._iter() + + def _iter(self): iterator = iter(self.ex_iterable) current_idx = 0 if self.batched: @@ -521,6 +711,44 @@ def __iter__(self): yield key, transformed_example current_idx += 1 + def _iter_arrow(self) -> Iterator[Tuple[Key, pa.Table]]: + if self.ex_iterable.iter_arrow: + iterator = _batch_arrow_tables( + self.ex_iterable.iter_arrow(), + batch_size=self.batch_size if self.batched else 1, + drop_last_batch=self.drop_last_batch, + ) + else: + iterator = _convert_to_arrow( + self.ex_iterable, + batch_size=self.batch_size if self.batched else 1, + drop_last_batch=self.drop_last_batch, + ) + current_idx = 0 + for key, pa_table in iterator: + # first build the batch + function_args = [pa_table] if self.input_columns is None else [pa_table[col] for col in self.input_columns] + if self.with_indices: + if self.batched: + function_args.append([current_idx + i for i in range(len(pa_table))]) + else: + function_args.append(current_idx) + # then apply the transform + 
output_table = self.function(*function_args, **self.fn_kwargs) + if not isinstance(output_table, pa.Table): + raise TypeError( + f"Provided `function` which is applied to pyarrow tables returns a variable of type {type(output_table)}. Make sure provided `function` returns a pyarrow table to update the dataset." + ) + # we don't need to merge results for consistency with Dataset.map which merges iff both input and output are dicts + # then remove the unwanted columns + if self.remove_columns: + for column in self.remove_columns: + if column in output_table.column_names: + output_table = output_table.remove_column(output_table.column_names.index(column)) + # return output + yield key, output_table + current_idx += len(pa_table) + def shuffle_data_sources(self, generator: np.random.Generator) -> "MappedExamplesIterable": """Shuffle the wrapped examples iterable.""" return MappedExamplesIterable( @@ -562,7 +790,9 @@ def __init__( batched: bool = False, batch_size: Optional[int] = 1000, fn_kwargs: Optional[dict] = None, + format_type: Optional[str] = None, ): + super().__init__() self.ex_iterable = ex_iterable self.function = function self.batched = batched @@ -570,8 +800,17 @@ def __init__( self.with_indices = with_indices self.input_columns = input_columns self.fn_kwargs = fn_kwargs or {} + self.format_type = get_format_type_from_alias(format_type) + if format_type == "arrow": + self.iter_arrow = self._iter_arrow def __iter__(self): + if self.format_type == "arrow": + yield from ArrowExamplesIterable(self._iter_arrow, {}) + else: + yield from self._iter() + + def _iter(self): iterator = iter(self.ex_iterable) current_idx = 0 if self.batched: @@ -608,6 +847,31 @@ def __iter__(self): yield key, example current_idx += 1 + def _iter_arrow(self): + if self.ex_iterable.iter_arrow: + iterator = _batch_arrow_tables( + self.ex_iterable.iter_arrow(), batch_size=self.batch_size if self.batched else 1 + ) + else: + iterator = _convert_to_arrow(self.ex_iterable, batch_size=self.batch_size if self.batched else 1) + current_idx = 0 + for key, pa_table in iterator: + # first build the batch + function_args = [pa_table] if self.input_columns is None else [pa_table[col] for col in self.input_columns] + if self.with_indices: + if self.batched: + function_args.append([current_idx + i for i in range(len(pa_table))]) + else: + function_args.append(current_idx) + # then apply the transform + mask = self.function(*function_args, **self.fn_kwargs) + # yield the filtered table + if self.batched: + yield key, pa_table.filter(mask) + elif mask.as_py() if isinstance(mask, pa.BooleanScalar) else mask: + yield key, pa_table + current_idx += len(pa_table) + def shuffle_data_sources(self, seed: Optional[int]) -> "FilteredExamplesIterable": """Shuffle the wrapped examples iterable.""" return FilteredExamplesIterable( @@ -637,9 +901,11 @@ def n_shards(self) -> int: class BufferShuffledExamplesIterable(_BaseExamplesIterable): def __init__(self, ex_iterable: _BaseExamplesIterable, buffer_size: int, generator: np.random.Generator): + super().__init__() self.ex_iterable = ex_iterable self.buffer_size = buffer_size self.generator = generator + # TODO(QL): implement iter_arrow @staticmethod def _iter_random_indices(rng: np.random.Generator, buffer_size: int, random_batch_size=1000) -> Iterator[int]: @@ -684,8 +950,10 @@ def n_shards(self) -> int: class SkipExamplesIterable(_BaseExamplesIterable): def __init__(self, ex_iterable: _BaseExamplesIterable, n: int): + super().__init__() self.ex_iterable = ex_iterable self.n = n + #
TODO(QL): implement iter_arrow def __iter__(self): yield from islice(self.ex_iterable, self.n, None) @@ -701,8 +969,10 @@ def n_shards(self) -> int: class TakeExamplesIterable(_BaseExamplesIterable): def __init__(self, ex_iterable: _BaseExamplesIterable, n: int): + super().__init__() self.ex_iterable = ex_iterable self.n = n + # TODO(QL): implement iter_arrow def __iter__(self): yield from islice(self.ex_iterable, self.n) @@ -770,9 +1040,12 @@ def __init__( features: Features, token_per_repo_id: Dict[str, Union[str, bool, None]], ): + super().__init__() self.ex_iterable = ex_iterable self.features = features self.token_per_repo_id = token_per_repo_id + if self.ex_iterable.iter_arrow is not None: + self.iter_arrow = self._iter_arrow def __iter__(self): # Then for each example, `TypedExamplesIterable` automatically fills missing columns with None. example = _apply_feature_types_on_example( example, self.features, token_per_repo_id=self.token_per_repo_id ) + def _iter_arrow(self) -> Iterator[Tuple[Key, pa.Table]]: + schema = self.features.arrow_schema + for key, pa_table in self.ex_iterable.iter_arrow(): + columns = set(pa_table.column_names) + # add missing columns + for column_name in self.features: + if column_name not in columns: + col = pa.NullArray.from_buffers(pa.null(), len(pa_table), [None]) + pa_table = pa_table.append_column(column_name, col) + if pa_table.schema != schema: + pa_table = cast_table_to_features(pa_table, self.features) + yield key, pa_table + def shuffle_data_sources(self, generator: np.random.Generator) -> "TypedExamplesIterable": """Shuffle the wrapped examples iterable.""" return TypedExamplesIterable( @@ -803,17 +1089,6 @@ def n_shards(self) -> int: return self.ex_iterable.n_shards -def _generate_examples_from_tables_wrapper(generate_tables_fn): - def wrapper(**kwargs): - python_formatter = PythonFormatter() - for key, table in generate_tables_fn(**kwargs): - batch = python_formatter.format_batch(table) - for i, example in enumerate(_batch_to_examples(batch)): - yield f"{key}_{i}", example - - return wrapper - - @dataclass class ShufflingConfig: generator: np.random.Generator @@ -975,6 +1250,10 @@ def _prepare_ex_iterable_for_iteration(self) -> _BaseExamplesIterable: return ex_iterable def __iter__(self): + if self._format_type == "arrow": + yield from self.iter(batch_size=1) + return + ex_iterable = self._prepare_ex_iterable_for_iteration() if "torch" in sys.modules: @@ -1004,7 +1283,27 @@ def iter(self, batch_size: int, drop_last_batch: bool = False): drop_last_batch (:obj:`bool`, default `False`): Whether a last batch smaller than the batch_size should be dropped """ - iterator = iter(self._prepare_ex_iterable_for_iteration()) + ex_iterable = self._prepare_ex_iterable_for_iteration() + if self._format_type == "arrow": + if ex_iterable.iter_arrow: + iterator = _batch_arrow_tables( + ex_iterable.iter_arrow(), batch_size=batch_size, drop_last_batch=drop_last_batch + ) + else: + iterator = _convert_to_arrow(ex_iterable, batch_size=batch_size, drop_last_batch=drop_last_batch) + for key, pa_table in iterator: + if self.features: + columns = set(pa_table.column_names) + # add missing columns + for column_name in self.features: + if column_name not in columns: + col = pa.NullArray.from_buffers(pa.null(), len(pa_table), [None]) + pa_table = pa_table.append_column(column_name, col) + yield cast_table_to_features(pa_table, self.features) + else: + yield pa_table + return + iterator = iter(ex_iterable) for key, example in iterator: # If batched, first build the batch
examples = [example] + [example for key, example in islice(iterator, batch_size - 1)] @@ -1121,7 +1420,8 @@ def with_format( ) -> "IterableDataset": """ Return a dataset with the specified format. - This method only supports the "torch" format for now. + Supported formats: "arrow", or None for regular python objects. + The other formats are currently not implemented. Args: @@ -1242,6 +1542,7 @@ def map( drop_last_batch=drop_last_batch, remove_columns=remove_columns, fn_kwargs=fn_kwargs, + format_type=self._format_type, ) info = self.info.copy() info.features = features @@ -1321,6 +1622,7 @@ def filter( batched=batched, batch_size=batch_size, fn_kwargs=fn_kwargs, + format_type=self._format_type, ) return IterableDataset( ex_iterable=ex_iterable, diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 8b1fb921b6d..bd6c7e4c67e 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -3,6 +3,8 @@ import numpy as np import pandas as pd +import pyarrow as pa +import pyarrow.compute as pc import pytest from datasets import load_dataset @@ -16,6 +18,7 @@ from datasets.formatting import get_format_type_from_alias from datasets.info import DatasetInfo from datasets.iterable_dataset import ( + ArrowExamplesIterable, BufferShuffledExamplesIterable, CyclingMultiSourcesExamplesIterable, ExamplesIterable, @@ -24,11 +27,19 @@ IterableDataset, MappedExamplesIterable, RandomlyCyclingMultiSourcesExamplesIterable, + SelectColumnsIterable, + ShuffledDataSourcesArrowExamplesIterable, + ShuffledDataSourcesExamplesIterable, ShufflingConfig, SkipExamplesIterable, + StepExamplesIterable, TakeExamplesIterable, + TypedExamplesIterable, VerticallyConcatenatedMultiSourcesExamplesIterable, + _BaseExamplesIterable, + _batch_arrow_tables, _batch_to_examples, + _convert_to_arrow, _examples_to_batch, ) @@ -42,6 +53,7 @@ DEFAULT_N_EXAMPLES = 20 +DEFAULT_BATCH_SIZE = 4 DEFAULT_FILEPATH = "file.txt" SAMPLE_DATASET_IDENTIFIER = "lhoestq/test" # has dataset script @@ -58,6 +70,25 @@ def generate_examples_fn(**kwargs): yield f"{filepath}_{i}", {"id": i, **kwargs} +def generate_tables_fn(**kwargs): + kwargs = kwargs.copy() + n = kwargs.pop("n", DEFAULT_N_EXAMPLES) + batch_size = kwargs.pop("batch_size", DEFAULT_BATCH_SIZE) + filepaths = kwargs.pop("filepaths", None) + for filepath in filepaths or [DEFAULT_FILEPATH]: + buffer = [] + batch_idx = 0 + if filepaths is not None: + kwargs["filepath"] = filepath + for i in range(n): + buffer.append({"id": i, **kwargs}) + if len(buffer) == batch_size: + yield f"{filepath}_{batch_idx}", pa.Table.from_pylist(buffer) + buffer = [] + batch_idx += 1 + yield f"{filepath}_{batch_idx}", pa.Table.from_pylist(buffer) + + @pytest.fixture def dataset(): ex_iterable = ExamplesIterable(generate_examples_fn, {}) @@ -73,6 +104,68 @@ def dataset_with_several_columns(): return IterableDataset(ex_iterable, info=DatasetInfo(description="dummy"), split="train") +################################ +# +# Utilities tests +# +################################ + + +@pytest.mark.parametrize("batch_size", [1, 2, 3, 9, 10, 11, 20]) +@pytest.mark.parametrize("drop_last_batch", [False, True]) +def test_convert_to_arrow(batch_size, drop_last_batch): + examples = [{"foo": i} for i in range(10)] + full_table = pa.Table.from_pylist(examples) + num_rows = len(full_table) if not drop_last_batch else len(full_table) // batch_size * batch_size + num_batches = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size + subtables = list( + _convert_to_arrow( + [(i,
example) for i, example in enumerate(examples)], + batch_size=batch_size, + drop_last_batch=drop_last_batch, + ) + ) + assert len(subtables) == num_batches + if drop_last_batch: + assert all(len(subtable) == batch_size for _, subtable in subtables) + else: + assert all(len(subtable) == batch_size for _, subtable in subtables[:-1]) + assert len(subtables[-1][1]) <= batch_size + if num_rows > 0: + reloaded = pa.concat_tables([subtable for _, subtable in subtables]) + assert full_table.slice(0, num_rows).to_pydict() == reloaded.to_pydict() + + +@pytest.mark.parametrize( + "tables", + [ + [pa.table({"foo": range(10)})], + [pa.table({"foo": range(0, 5)}), pa.table({"foo": range(5, 10)})], + [pa.table({"foo": [i]}) for i in range(10)], + ], +) +@pytest.mark.parametrize("batch_size", [1, 2, 3, 9, 10, 11, 20]) +@pytest.mark.parametrize("drop_last_batch", [False, True]) +def test_batch_arrow_tables(tables, batch_size, drop_last_batch): + full_table = pa.concat_tables(tables) + num_rows = len(full_table) if not drop_last_batch else len(full_table) // batch_size * batch_size + num_batches = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size + subtables = list( + _batch_arrow_tables( + [(i, table) for i, table in enumerate(tables)], batch_size=batch_size, drop_last_batch=drop_last_batch + ) + ) + assert len(subtables) == num_batches + if drop_last_batch: + assert all(len(subtable) == batch_size for _, subtable in subtables) + else: + assert all(len(subtable) == batch_size for _, subtable in subtables[:-1]) + assert len(subtables[-1][1]) <= batch_size + if num_rows > 0: + reloaded = pa.concat_tables([subtable for _, subtable in subtables]) + assert full_table.slice(0, num_rows).to_pydict() == reloaded.to_pydict() + + ################################ # # _BaseExampleIterable tests @@ -85,6 +178,7 @@ def test_examples_iterable(): expected = list(generate_examples_fn()) assert next(iter(ex_iterable)) == expected[0] assert list(ex_iterable) == expected + assert ex_iterable.iter_arrow is None def test_examples_iterable_with_kwargs(): @@ -121,6 +215,38 @@ def gen(filepaths, all_metadata): assert filepaths_ids == metadata_ids, "entangled lists of shards/metadata should be shuffled the same way" +def test_arrow_examples_iterable(): + ex_iterable = ArrowExamplesIterable(generate_tables_fn, {}) + expected = sum([pa_table.to_pylist() for _, pa_table in generate_tables_fn()], []) + assert next(iter(ex_iterable))[1] == expected[0] + assert [example for _, example in ex_iterable] == expected + expected = list(generate_tables_fn()) + assert list(ex_iterable.iter_arrow()) == expected + + +def test_arrow_examples_iterable_with_kwargs(): + ex_iterable = ArrowExamplesIterable(generate_tables_fn, {"filepaths": ["0.txt", "1.txt"], "split": "train"}) + expected = sum( + [pa_table.to_pylist() for _, pa_table in generate_tables_fn(filepaths=["0.txt", "1.txt"], split="train")], [] + ) + assert [example for _, example in ex_iterable] == expected + assert all("split" in ex for _, ex in ex_iterable) + assert sorted({ex["filepath"] for _, ex in ex_iterable}) == ["0.txt", "1.txt"] + expected = list(generate_tables_fn(filepaths=["0.txt", "1.txt"], split="train")) + assert list(ex_iterable.iter_arrow()) == expected + + +def test_arrow_examples_iterable_shuffle_data_sources(): + ex_iterable = ArrowExamplesIterable(generate_tables_fn, {"filepaths": ["0.txt", "1.txt"]}) + ex_iterable = ex_iterable.shuffle_data_sources(np.random.default_rng(40)) + expected = sum( + [pa_table.to_pylist() for _, pa_table in 
generate_tables_fn(filepaths=["1.txt", "0.txt"])], [] + ) # shuffle the filepaths + assert [example for _, example in ex_iterable] == expected + expected = list(generate_tables_fn(filepaths=["1.txt", "0.txt"])) + assert list(ex_iterable.iter_arrow()) == expected + + @pytest.mark.parametrize("seed", [42, 1337, 101010, 123456]) def test_buffer_shuffled_examples_iterable(seed): n, buffer_size = 100, 30 @@ -446,6 +572,254 @@ def test_mapped_examples_iterable_input_columns(n, func, batched, batch_size, in assert [x for _, x in ex_iterable] == expected +@pytest.mark.parametrize( + "n, func, batched, batch_size", + [ + (3, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), False, None), # just add 1 to the id + (3, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, 1), # same with bs=1 + (5, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, 10), # same with bs=10 + (25, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, 10), # same with bs=10 + (5, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, None), # same with bs=None + (5, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, -1), # same with bs<=0 + (3, lambda t: pa.concat_tables([t] * 2), True, 1), # make a duplicate of each example + ], +) +def test_mapped_examples_iterable_arrow_format(n, func, batched, batch_size): + base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n}) + ex_iterable = MappedExamplesIterable( + base_ex_iterable, func, batched=batched, batch_size=batch_size, format_type="arrow" + ) + all_examples = [x for _, x in generate_examples_fn(n=n)] + if batched is False: + expected = [func(pa.Table.from_pylist([x])).to_pylist()[0] for x in all_examples] + else: + expected = [] + # If batch_size is None or <=0, we use the whole dataset as a single batch + if batch_size is None or batch_size <= 0: + batch_size = len(all_examples) + for batch_offset in range(0, len(all_examples), batch_size): + examples = all_examples[batch_offset : batch_offset + batch_size] + batch = pa.Table.from_pylist(examples) + expected.extend(func(batch).to_pylist()) + assert next(iter(ex_iterable))[1] == expected[0] + assert [x for _, x in ex_iterable] == expected + + +@pytest.mark.parametrize( + "n, func, batched, batch_size", + [ + (3, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), False, None), # just add 1 to the id + (3, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, 1), # same with bs=1 + (5, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, 10), # same with bs=10 + (25, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, 10), # same with bs=10 + (5, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, None), # same with bs=None + (5, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, -1), # same with bs<=0 + (3, lambda t: pa.concat_tables([t] * 2), True, 1), # make a duplicate of each example + ], +) +def test_mapped_examples_iterable_drop_last_batch_and_arrow_format(n, func, batched, batch_size): + base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n}) + ex_iterable = MappedExamplesIterable( + base_ex_iterable, func, batched=batched, batch_size=batch_size, drop_last_batch=True, format_type="arrow" + ) + all_examples = [x for _, x in generate_examples_fn(n=n)] + is_empty = False + if batched is False: + # `drop_last_batch` has no effect here + expected = [func(pa.Table.from_pylist([x])).to_pylist()[0] for x in all_examples] + else: + all_transformed_examples = [] + # If batch_size is None or <=0, we use the 
whole dataset as a single batch + if batch_size is None or batch_size <= 0: + batch_size = len(all_examples) + for batch_offset in range(0, len(all_examples), batch_size): + examples = all_examples[batch_offset : batch_offset + batch_size] + if len(examples) < batch_size: # ignore last batch + break + batch = pa.Table.from_pylist(examples) + out = func(batch) + all_transformed_examples.extend( + out.to_pylist() + ) # we don't merge with input since they're arrow tables and not dictionaries + all_examples = all_examples if n % batch_size == 0 else all_examples[: n // batch_size * batch_size] + if all_examples: + expected = all_transformed_examples + else: + is_empty = True + + if not is_empty: + assert next(iter(ex_iterable))[1] == expected[0] + assert [x for _, x in ex_iterable] == expected + else: + with pytest.raises(StopIteration): + next(iter(ex_iterable)) + + +@pytest.mark.parametrize( + "n, func, batched, batch_size", + [ + ( + 3, + lambda t, index: t.append_column("id+idx", pc.add(t["id"], index)), + False, + None, + ), # add the index to the id + ( + 25, + lambda t, indices: t.append_column("id+idx", pc.add(t["id"], indices)), + True, + 10, + ), # add the index to the id + (5, lambda t, indices: t.append_column("id+idx", pc.add(t["id"], indices)), True, None), # same with bs=None + (5, lambda t, indices: t.append_column("id+idx", pc.add(t["id"], indices)), True, -1), # same with bs<=0 + ], +) +def test_mapped_examples_iterable_with_indices_and_arrow_format(n, func, batched, batch_size): + base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n}) + ex_iterable = MappedExamplesIterable( + base_ex_iterable, func, batched=batched, batch_size=batch_size, with_indices=True, format_type="arrow" + ) + all_examples = [x for _, x in generate_examples_fn(n=n)] + if batched is False: + expected = [func(pa.Table.from_pylist([x]), i).to_pylist()[0] for i, x in enumerate(all_examples)] + else: + expected = [] + # If batch_size is None or <=0, we use the whole dataset as a single batch + if batch_size is None or batch_size <= 0: + batch_size = len(all_examples) + for batch_offset in range(0, len(all_examples), batch_size): + examples = all_examples[batch_offset : batch_offset + batch_size] + batch = pa.Table.from_pylist(examples) + expected.extend(func(batch, list(range(batch_offset, batch_offset + len(batch)))).to_pylist()) + assert next(iter(ex_iterable))[1] == expected[0] + assert [x for _, x in ex_iterable] == expected + + +@pytest.mark.parametrize( + "n, func, batched, batch_size, remove_columns", + [ + ( + 3, + lambda t: t.append_column("id+1", pc.add(t["id"], 1)), + False, + None, + ["extra_column"], + ), # just add 1 to the id + (25, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, 10, ["extra_column"]), # same with bs=10 + ( + 50, + lambda t: pa.table({"foo": ["bar"] * np.random.default_rng(t["id"][0].as_py()).integers(0, 10)}), + True, + 8, + ["extra_column", "id"], + ), # make a duplicate of each example + (5, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, None, ["extra_column"]), # same with bs=None + (5, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, -1, ["extra_column"]), # same with bs<=0 + ], +) +def test_mapped_examples_iterable_remove_columns_arrow_format(n, func, batched, batch_size, remove_columns): + base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n, "extra_column": "foo"}) + ex_iterable = MappedExamplesIterable( + base_ex_iterable, + func, + batched=batched, + batch_size=batch_size, + 
remove_columns=remove_columns, + format_type="arrow", + ) + all_examples = [x for _, x in generate_examples_fn(n=n)] + columns_to_remove = remove_columns if isinstance(remove_columns, list) else [remove_columns] + if batched is False: + expected = [ + {**{k: v for k, v in func(pa.Table.from_pylist([x])).to_pylist()[0].items() if k not in columns_to_remove}} + for x in all_examples + ] + else: + expected = [] + # If batch_size is None or <=0, we use the whole dataset as a single batch + if batch_size is None or batch_size <= 0: + batch_size = len(all_examples) + for batch_offset in range(0, len(all_examples), batch_size): + examples = all_examples[batch_offset : batch_offset + batch_size] + batch = pa.Table.from_pylist(examples) + expected.extend( + [{k: v for k, v in x.items() if k not in columns_to_remove} for x in func(batch).to_pylist()] + ) + assert next(iter(ex_iterable))[1] == expected[0] + assert [x for _, x in ex_iterable] == expected + + +@pytest.mark.parametrize( + "n, func, batched, batch_size, fn_kwargs", + [ + (3, lambda t, y=0: t.append_column("id+idx", pc.add(t["id"], y)), False, None, None), + (3, lambda t, y=0: t.append_column("id+idx", pc.add(t["id"], y)), False, None, {"y": 3}), + (25, lambda t, y=0: t.append_column("id+idx", pc.add(t["id"], y)), True, 10, {"y": 3}), + (5, lambda t, y=0: t.append_column("id+idx", pc.add(t["id"], y)), True, None, {"y": 3}), # same with bs=None + (5, lambda t, y=0: t.append_column("id+idx", pc.add(t["id"], y)), True, -1, {"y": 3}), # same with bs<=0 + ], +) +def test_mapped_examples_iterable_fn_kwargs_and_arrow_format(n, func, batched, batch_size, fn_kwargs): + base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n}) + ex_iterable = MappedExamplesIterable( + base_ex_iterable, func, batched=batched, batch_size=batch_size, fn_kwargs=fn_kwargs, format_type="arrow" + ) + all_examples = [x for _, x in generate_examples_fn(n=n)] + if fn_kwargs is None: + fn_kwargs = {} + if batched is False: + expected = [func(pa.Table.from_pylist([x]), **fn_kwargs).to_pylist()[0] for x in all_examples] + else: + expected = [] + # If batch_size is None or <=0, we use the whole dataset as a single batch + if batch_size is None or batch_size <= 0: + batch_size = len(all_examples) + for batch_offset in range(0, len(all_examples), batch_size): + examples = all_examples[batch_offset : batch_offset + batch_size] + batch = pa.Table.from_pylist(examples) + expected.extend(func(batch, **fn_kwargs).to_pylist()) + assert next(iter(ex_iterable))[1] == expected[0] + assert [x for _, x in ex_iterable] == expected + + +@pytest.mark.parametrize( + "n, func, batched, batch_size, input_columns", + [ + (3, lambda id_: pa.table({"id+1": pc.add(id_, 1)}), False, None, ["id"]), # just add 1 to the id + (25, lambda ids_: pa.table({"id+1": pc.add(ids_, 1)}), True, 10, ["id"]), # same with bs=10 + (5, lambda ids_: pa.table({"id+1": pc.add(ids_, 1)}), True, None, ["id"]), # same with bs=None + (5, lambda ids_: pa.table({"id+1": pc.add(ids_, 1)}), True, -1, ["id"]), # same with bs<=0 + ], +) +def test_mapped_examples_iterable_input_columns_and_arrow_format(n, func, batched, batch_size, input_columns): + base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n}) + ex_iterable = MappedExamplesIterable( + base_ex_iterable, + func, + batched=batched, + batch_size=batch_size, + input_columns=input_columns, + format_type="arrow", + ) + all_examples = [x for _, x in generate_examples_fn(n=n)] + columns_to_input = input_columns if isinstance(input_columns, list) else 
[input_columns] + if batched is False: + expected = [ + func(*[pa.Table.from_pylist([x])[col] for col in columns_to_input]).to_pylist()[0] for x in all_examples + ] + else: + expected = [] + # If batch_size is None or <=0, we use the whole dataset as a single batch + if batch_size is None or batch_size <= 0: + batch_size = len(all_examples) + for batch_offset in range(0, len(all_examples), batch_size): + examples = all_examples[batch_offset : batch_offset + batch_size] + batch = pa.Table.from_pylist(examples) + expected.extend(func(*[batch[col] for col in columns_to_input]).to_pylist()) + assert next(iter(ex_iterable))[1] == expected[0] + assert [x for _, x in ex_iterable] == expected + + @pytest.mark.parametrize( "n, func, batched, batch_size", [ @@ -614,6 +988,64 @@ def test_horizontally_concatenated_examples_iterable(): ), "horizontally concatenated examples makes the shards order fixed" +@pytest.mark.parametrize( + "ex_iterable", + [ + ExamplesIterable(generate_examples_fn, {}), + ShuffledDataSourcesExamplesIterable(generate_examples_fn, {}, np.random.default_rng(42)), + SelectColumnsIterable(ExamplesIterable(generate_examples_fn, {}), ["id"]), + StepExamplesIterable(ExamplesIterable(generate_examples_fn, {}), 2, 0), + CyclingMultiSourcesExamplesIterable([ExamplesIterable(generate_examples_fn, {})]), + VerticallyConcatenatedMultiSourcesExamplesIterable([ExamplesIterable(generate_examples_fn, {})]), + HorizontallyConcatenatedMultiSourcesExamplesIterable([ExamplesIterable(generate_examples_fn, {})]), + RandomlyCyclingMultiSourcesExamplesIterable( + [ExamplesIterable(generate_examples_fn, {})], np.random.default_rng(42) + ), + MappedExamplesIterable(ExamplesIterable(generate_examples_fn, {}), lambda x: x), + MappedExamplesIterable(ArrowExamplesIterable(generate_tables_fn, {}), lambda x: x), + FilteredExamplesIterable(ExamplesIterable(generate_examples_fn, {}), lambda x: True), + FilteredExamplesIterable(ArrowExamplesIterable(generate_tables_fn, {}), lambda x: True), + BufferShuffledExamplesIterable(ExamplesIterable(generate_examples_fn, {}), 10, np.random.default_rng(42)), + SkipExamplesIterable(ExamplesIterable(generate_examples_fn, {}), 10), + TakeExamplesIterable(ExamplesIterable(generate_examples_fn, {}), 10), + TypedExamplesIterable( + ExamplesIterable(generate_examples_fn, {}), Features({"id": Value("int32")}), token_per_repo_id={} + ), + ], +) +def test_no_iter_arrow(ex_iterable: _BaseExamplesIterable): + assert ex_iterable.iter_arrow is None + + +@pytest.mark.parametrize( + "ex_iterable", + [ + ArrowExamplesIterable(generate_tables_fn, {}), + ShuffledDataSourcesArrowExamplesIterable(generate_tables_fn, {}, np.random.default_rng(42)), + SelectColumnsIterable(ArrowExamplesIterable(generate_tables_fn, {}), ["id"]), + # StepExamplesIterable(ArrowExamplesIterable(generate_tables_fn, {}), 2, 0), # not implemented + # CyclingMultiSourcesExamplesIterable([ArrowExamplesIterable(generate_tables_fn, {})]), # not implemented + VerticallyConcatenatedMultiSourcesExamplesIterable([ArrowExamplesIterable(generate_tables_fn, {})]), + # HorizontallyConcatenatedMultiSourcesExamplesIterable([ArrowExamplesIterable(generate_tables_fn, {})]), # not implemented + # RandomlyCyclingMultiSourcesExamplesIterable([ArrowExamplesIterable(generate_tables_fn, {})], np.random.default_rng(42)), # not implemented + MappedExamplesIterable(ExamplesIterable(generate_examples_fn, {}), lambda t: t, format_type="arrow"), + MappedExamplesIterable(ArrowExamplesIterable(generate_tables_fn, {}), lambda t: t, 
format_type="arrow"), + FilteredExamplesIterable(ExamplesIterable(generate_examples_fn, {}), lambda t: True, format_type="arrow"), + FilteredExamplesIterable(ArrowExamplesIterable(generate_tables_fn, {}), lambda t: True, format_type="arrow"), + # BufferShuffledExamplesIterable(ArrowExamplesIterable(generate_tables_fn, {}), 10, np.random.default_rng(42)), # not implemented + # SkipExamplesIterable(ArrowExamplesIterable(generate_tables_fn, {}), 10), # not implemented + # TakeExamplesIterable(ArrowExamplesIterable(generate_tables_fn, {}), 10), # not implemented + TypedExamplesIterable( + ArrowExamplesIterable(generate_tables_fn, {}), Features({"id": Value("int32")}), token_per_repo_id={} + ), + ], +) +def test_iter_arrow(ex_iterable: _BaseExamplesIterable): + assert ex_iterable.iter_arrow is not None + key, pa_table = next(ex_iterable.iter_arrow()) + assert isinstance(pa_table, pa.Table) + + ############################ # # IterableDataset tests
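The two private helpers introduced above can also be exercised directly, using the same private import path as tests/test_iterable_dataset.py. A minimal sketch of their contract; the toy tables, keys and printed output are illustrative only:

import pyarrow as pa

from datasets.iterable_dataset import _batch_arrow_tables, _convert_to_arrow

# _convert_to_arrow groups (key, example) pairs into pyarrow tables of batch_size rows,
# joining the source keys into the new key (e.g. "0_1").
examples = [(i, {"id": i}) for i in range(5)]
for key, pa_table in _convert_to_arrow(examples, batch_size=2):
    print(key, pa_table.to_pylist())

# _batch_arrow_tables reslices (key, table) pairs into tables of exactly batch_size rows
# (except possibly the last one), producing sliced keys like "0[3:]_1[:1]" when a
# source table has to be split across output batches.
tables = [(0, pa.table({"id": range(0, 5)})), (1, pa.table({"id": range(5, 10)}))]
for key, pa_table in _batch_arrow_tables(tables, batch_size=3):
    print(key, pa_table["id"].to_pylist())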
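End to end, the intended flow is Dataset.to_iterable_dataset() followed by with_format("arrow"): iteration then streams pyarrow.Table objects, and map hands tables straight to the user function instead of decoding python examples. A minimal usage sketch, assuming the datasets API as of this change; the column names and batch sizes are made up:

import pyarrow.compute as pc

from datasets import Dataset

ds = Dataset.from_dict({"id": list(range(10))})

# Stream the map-style dataset shard by shard as Arrow tables.
iterable_ds = ds.to_iterable_dataset(num_shards=2).with_format("arrow")

# In the "arrow" format, `function` receives a pa.Table and must return a pa.Table.
iterable_ds = iterable_ds.map(
    lambda t: t.append_column("id+1", pc.add(t["id"], 1)), batched=True, batch_size=4
)

# iter() rebatches the underlying tables to the requested size via _batch_arrow_tables.
for pa_table in iterable_ds.iter(batch_size=4):
    print(pa_table["id+1"].to_pylist())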
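filter gets the same treatment: with the "arrow" format the predicate receives a pa.Table (or the selected input columns) and returns a boolean mask, which FilteredExamplesIterable applies with pa.Table.filter. A sketch under the same assumptions:

import pyarrow.compute as pc

from datasets import Dataset

iterable_ds = Dataset.from_dict({"id": list(range(10))}).to_iterable_dataset().with_format("arrow")

# Keep even ids: the mask is computed with pyarrow.compute over the whole batch.
even = iterable_ds.filter(
    lambda t: pc.equal(pc.bit_wise_and(t["id"], 1), 0), batched=True, batch_size=5
)

# Switching back to the default format yields plain python examples again.
print([example["id"] for example in even.with_format(None)])  # [0, 2, 4, 6, 8]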