From b860cf62a698b06be8e5f20e1902c1dccb526cad Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Tue, 2 May 2023 18:26:25 +0200 Subject: [PATCH 01/10] add iterable arrow formatting --- src/datasets/arrow_dataset.py | 23 ++- src/datasets/iterable_dataset.py | 295 ++++++++++++++++++++++++++++++- 2 files changed, 301 insertions(+), 17 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index bb3d067d34e..69661a94f9a 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -88,7 +88,7 @@ update_fingerprint, validate_fingerprint, ) -from .formatting import PythonFormatter, format_table, get_format_type_from_alias, get_formatter, query_table +from .formatting import format_table, get_format_type_from_alias, get_formatter, query_table from .formatting.formatting import LazyDict, _is_range_contiguous from .info import DatasetInfo, DatasetInfosDict from .naming import _split_re @@ -4933,16 +4933,10 @@ def extra_nbytes_visitor(array, feature): return dataset_nbytes @staticmethod - def _generate_examples_from_shards(shards: List["Dataset"]): - python_formatter = PythonFormatter() - for shards_idx, shard in enumerate(shards): - example_idx = 0 - for pa_table in shard.with_format("arrow").iter(config.ARROW_READER_BATCH_SIZE_IN_DATASET_ITER): - batch = python_formatter.format_batch(pa_table) - for i in range(len(pa_table)): - example = {col: array[i] for col, array in batch.items()} - yield f"{shards_idx}_{example_idx}", example - example_idx += 1 + def _generate_tables_from_shards(shards: List["Dataset"], batch_size: int): + for shard_idx, shard in enumerate(shards): + for pa_table in shard.with_format("arrow").iter(batch_size): + yield shard_idx, pa_table def to_iterable_dataset(self, num_shards: Optional[int] = 1) -> "IterableDataset": """Get an [`datasets.IterableDataset`] from a map-style [`datasets.Dataset`]. @@ -5035,7 +5029,7 @@ def to_iterable_dataset(self, num_shards: Optional[int] = 1) -> "IterableDataset ``` Feel free to also use [`IterableDataset.set_epoch`] when using a PyTorch DataLoader or in distributed setups. 
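Because the conversion now streams Arrow tables under the hood, the resulting dataset can also yield raw `pyarrow.Table` batches once it is put in "arrow" format (a minimal sketch of the intended usage; the exact batch boundaries follow the underlying shards):

```python
ids = ds.to_iterable_dataset(num_shards=4).with_format("arrow")
for pa_table in ids.iter(batch_size=1000):
    ...  # each item is a pyarrow.Table of up to 1000 rows
```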
""" - from .iterable_dataset import ExamplesIterable, IterableDataset + from .iterable_dataset import ArrowExamplesIterable, IterableDataset if self._format_type is not None: raise NotImplementedError( @@ -5057,7 +5051,10 @@ def to_iterable_dataset(self, num_shards: Optional[int] = 1) -> "IterableDataset self.shard(num_shards=num_shards, index=shard_idx, contiguous=True) for shard_idx in range(num_shards) ] ) - ex_iterable = ExamplesIterable(Dataset._generate_examples_from_shards, kwargs={"shards": shards}) + ex_iterable = ArrowExamplesIterable( + Dataset._generate_tables_from_shards, + kwargs={"shards": shards, "batch_size": config.DEFAULT_MAX_BATCH_SIZE}, + ) return IterableDataset(ex_iterable, info=DatasetInfo(features=self.features)) def _push_parquet_shards_to_hub( diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 83465dd85f5..9a46e7fa84b 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -5,7 +5,7 @@ from copy import deepcopy from dataclasses import dataclass from itertools import cycle, islice -from typing import Any, Callable, Dict, Iterator, List, Optional, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union import numpy as np import pyarrow as pa @@ -18,7 +18,7 @@ from .formatting import PythonFormatter, get_format_type_from_alias from .info import DatasetInfo from .splits import NamedSplit -from .table import table_cast +from .table import cast_table_to_features, table_cast from .utils.logging import get_logger from .utils.sharding import _merge_gen_kwargs, _number_of_shards_in_gen_kwargs, _shuffle_gen_kwargs, _split_gen_kwargs @@ -81,9 +81,59 @@ def hasnext(self): return self._hasnext +def _batch_arrow_tables( + iterator: Iterator[Tuple[Any, pa.Table]], + batch_size: int, + drop_last_batch=False, +) -> Iterator[Tuple[Any, pa.Table]]: + """Iterate over sub-tables of size `batch_size`. + + Args: + batch_size (`int`): + Size of each sub-table to yield. + drop_last_batch (`bool`, defaults to `False`): + Drop the last batch if it is smaller than `batch_size`. 
+ """ + keys_buffer = [] + chunks_buffer = [] + chunks_buffer_size = 0 + for key, pa_table in iterator: + for chunk in pa_table.to_reader(max_chunksize=batch_size): + if len(chunk) == 0: + continue + elif chunks_buffer_size + len(chunk) < batch_size: + keys_buffer.append(key) + chunks_buffer.append(chunk) + chunks_buffer_size += len(chunk) + continue + elif chunks_buffer_size + len(chunk) == batch_size: + keys_buffer.append(key) + chunks_buffer.append(chunk) + new_key = "_".join(str(_key) for _key in keys_buffer) + yield new_key, pa.Table.from_batches(chunks_buffer) + keys_buffer = [] + chunks_buffer = [] + chunks_buffer_size = 0 + else: + cropped_chunk_length = batch_size - chunks_buffer_size + keys_buffer.append(f"{key}[:{cropped_chunk_length}]") + chunks_buffer.append(chunk.slice(0, cropped_chunk_length)) + new_key = "_".join(str(_key) for _key in keys_buffer) + yield new_key, pa.Table.from_batches(chunks_buffer) + keys_buffer = [f"{key}[{cropped_chunk_length}:]"] + chunks_buffer = [chunk.slice(cropped_chunk_length, len(chunk) - cropped_chunk_length)] + chunks_buffer_size = len(chunk) - cropped_chunk_length + if not drop_last_batch and chunks_buffer: + new_key = "_".join(str(_key) for _key in keys_buffer) + yield new_key, pa.Table.from_batches(chunks_buffer) + + class _BaseExamplesIterable: """Base class for the examples iterable used by an IterableDataset""" + def __init__(self) -> None: + self.iter_arrow: Optional[Callable[[], Iterator[Tuple[Any, pa.Table]]]] = None + def __iter__(self): """An examples iterable should yield tuples (example_key, example) of type (int/str, dict)""" raise NotImplementedError(f"{type(self)} doesn't implement __iter__ yet") @@ -106,6 +156,7 @@ def n_shards(self) -> int: class ExamplesIterable(_BaseExamplesIterable): def __init__(self, generate_examples_fn: Callable, kwargs: dict): + super().__init__() self.generate_examples_fn = generate_examples_fn self.kwargs = kwargs @@ -146,15 +197,121 @@ def shard_data_sources(self, shard_indices: List[int]) -> "ExamplesIterable": ) +class ArrowExamplesIterable(_BaseExamplesIterable): + def __init__(self, generate_tables_fn: Callable, kwargs: dict): + super().__init__() + self.generate_tables_fn = generate_tables_fn + self.kwargs = kwargs + self.iter_arrow = self._iter_arrow + + def __iter__(self): + formatter = PythonFormatter() + for key, pa_table in self.generate_tables_fn(**self.kwargs): + for pa_subtable in pa_table.to_reader(max_chunksize=config.ARROW_READER_BATCH_SIZE_IN_DATASET_ITER): + formatted_batch = formatter.format_batch(pa_subtable) + for example in _batch_to_examples(formatted_batch): + yield key, example + + def _iter_arrow(self): + yield from self.generate_tables_fn(**self.kwargs) + + def shuffle_data_sources(self, generator: np.random.Generator) -> "ExamplesIterable": + return ShuffledDataSourcesArrowExamplesIterable(self.generate_tables_fn, self.kwargs, generator) + + def shard_data_sources(self, shard_indices: List[int]) -> "ExamplesIterable": + """Keep only the requested shard.""" + gen_kwargs_list = _split_gen_kwargs(self.kwargs, max_num_jobs=self.n_shards) + requested_gen_kwargs = _merge_gen_kwargs([gen_kwargs_list[i] for i in shard_indices]) + return ArrowExamplesIterable( + self.generate_tables_fn, requested_gen_kwargs, generate_tables_fn=self.generate_tables_fn + ) + + @property + def n_shards(self) -> int: + return _number_of_shards_in_gen_kwargs(self.kwargs) + + +class ShuffledDataSourcesArrowExamplesIterable(ArrowExamplesIterable): + def __init__( + self, + generate_tables_fn: Callable, + 
kwargs: dict, + generator: np.random.Generator, + ): + super().__init__(generate_tables_fn, kwargs) + self.generator = deepcopy(generator) + + def __iter__(self): + """Shuffle the kwargs order to shuffle shards""" + rng = deepcopy(self.generator) + kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) + formatter = PythonFormatter() + for key, pa_table in self.generate_tables_fn(**kwargs_with_shuffled_shards): + for pa_subtable in pa_table.to_reader(max_chunksize=config.ARROW_READER_BATCH_SIZE_IN_DATASET_ITER): + formatted_batch = formatter.format_batch(pa_subtable) + for example in _batch_to_examples(formatted_batch): + yield key, example + + def _iter_arrow(self): + rng = deepcopy(self.generator) + kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) + yield from self.generate_tables_fn(**kwargs_with_shuffled_shards) + + def shard_data_sources(self, shard_indices: List[int]) -> "ExamplesIterable": + """Keep only the requested shard.""" + rng = deepcopy(self.generator) + kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) + return ArrowExamplesIterable(self.generate_tables_fn, kwargs_with_shuffled_shards).shard_data_sources( + shard_indices + ) + + +class PythonToArrowExamplesIterable(_BaseExamplesIterable): + def __init__(self, ex_iterable: _BaseExamplesIterable): + super().__init__() + self.ex_iterable = ex_iterable + self.iter_arrow = self._iter_arrow + + def __iter__(self): + yield from self.ex_iterable + + def _iter_arrow(self) -> Iterator[Tuple[Any, pa.Table]]: + iterator = iter(self.ex_iterable) + batch_size = config.ARROW_READER_BATCH_SIZE_IN_DATASET_ITER + for key, example in iterator: + iterator_batch = islice(iterator, batch_size - 1) + key_examples_list = [(key, example)] + [(key, example) for key, example in iterator_batch] + keys, examples = zip(*key_examples_list) + new_key = "_".join(str(key) for key in keys) + yield new_key, pa.Table.from_pylist(examples) + + def shuffle_data_sources(self, generator: np.random.Generator) -> "PythonToArrowExamplesIterable": + return PythonToArrowExamplesIterable(self.ex_iterable.shuffle_data_sources(generator)) + + def shard_data_sources(self, shard_indices: List[int]) -> "PythonToArrowExamplesIterable": + return PythonToArrowExamplesIterable(self.ex_iterable.shard_data_sources(shard_indices)) + + @property + def n_shards(self) -> int: + return self.ex_iterable.n_shards + + class SelectColumnsIterable(_BaseExamplesIterable): def __init__(self, ex_iterable: _BaseExamplesIterable, column_names: List[str]): + super().__init__() self.ex_iterable = ex_iterable self.column_names = column_names + if self.ex_iterable.iter_arrow: + self.iter_arrow = self._iter_arrow def __iter__(self): for idx, row in self.ex_iterable: yield idx, {c: row[c] for c in self.column_names} + def _iter_arrow(self) -> Iterator[Tuple[Any, pa.Table]]: + for idx, pa_table in self.ex_iterable.iter_arrow(): + yield idx, pa_table.select(self.column_names) + def shuffle_data_sources(self, generator: np.random.Generator) -> "SelectColumnsIterable": return SelectColumnsIterable(self.ex_iterable.shuffle_data_sources(generator), self.column_names) @@ -168,9 +325,11 @@ def n_shards(self) -> int: class StepExamplesIterable(_BaseExamplesIterable): def __init__(self, ex_iterable: _BaseExamplesIterable, step: int, offset: int): + super().__init__() self.ex_iterable = ex_iterable self.step = step self.offset = offset + # TODO(QL): implement iter_arrow def __iter__(self): ex_iterator = iter(self.ex_iterable) @@ -200,12 +359,14 @@ class 
CyclingMultiSourcesExamplesIterable(_BaseExamplesIterable): def __init__( self, ex_iterables: List[_BaseExamplesIterable], stopping_strategy: Optional[str] = "first_exhausted" ): + super().__init__() self.ex_iterables = ex_iterables self.stopping_strategy = stopping_strategy # if undersampling ("first_exhausted"), we stop as soon as one dataset is exhausted # if oversampling ("all_exhausted"), we stop as soon as every dataset is exhausted, i.e. as soon as every sample of every dataset has been visited at least once self.bool_strategy_func = np.all if (stopping_strategy == "all_exhausted") else np.any + # TODO(QL): implement iter_arrow def _give_indice_iterator(self): # this is an infinite iterator to keep track of which iterator we want to pick examples from @@ -268,7 +429,9 @@ class VerticallyConcatenatedMultiSourcesExamplesIterable(_BaseExamplesIterable): """ def __init__(self, ex_iterables: List[_BaseExamplesIterable]): + super().__init__() self.ex_iterables = ex_iterables + # TODO(QL): implement iter_arrow def __iter__(self): for ex_iterable in self.ex_iterables: yield from ex_iterable @@ -320,7 +483,9 @@ class HorizontallyConcatenatedMultiSourcesExamplesIterable(_BaseExamplesIterable """ def __init__(self, ex_iterables: List[_BaseExamplesIterable]): + super().__init__() self.ex_iterables = ex_iterables + # TODO(QL): implement iter_arrow def __iter__(self): ex_iterators = [iter(ex_iterable) for ex_iterable in self.ex_iterables] @@ -363,7 +528,7 @@ def shard_data_sources(self, shard_idx: int) -> "HorizontallyConcatenatedMultiSo class RandomlyCyclingMultiSourcesExamplesIterable(CyclingMultiSourcesExamplesIterable): def __init__( self, - ex_iterables, + ex_iterables: List[_BaseExamplesIterable], generator: np.random.Generator, probabilities: Optional[List[float]] = None, stopping_strategy: Optional[str] = "first_exhausted", ) @@ -371,6 +536,7 @@ def __init__( super().__init__(ex_iterables, stopping_strategy) self.generator = deepcopy(generator) self.probabilities = probabilities + # TODO(QL): implement iter_arrow @staticmethod def _iter_random_indices( @@ -416,7 +582,9 @@ def __init__( drop_last_batch: bool = False, remove_columns: Optional[List[str]] = None, fn_kwargs: Optional[dict] = None, + format_type: Optional[str] = None, ): + super().__init__() self.ex_iterable = ex_iterable self.function = function self.batched = batched @@ -426,8 +594,17 @@ def __init__( self.with_indices = with_indices self.input_columns = input_columns self.fn_kwargs = fn_kwargs or {} + self.format_type = get_format_type_from_alias(format_type) + if format_type == "arrow": + self.iter_arrow = self._iter_arrow def __iter__(self): + if self.format_type == "arrow": + yield from ArrowExamplesIterable(self._iter_arrow, {}) + else: + yield from self._iter() + + def _iter(self): iterator = iter(self.ex_iterable) current_idx = 0 if self.batched: @@ -495,6 +672,40 @@ def __iter__(self): yield key, transformed_example current_idx += 1 + def _iter_arrow(self) -> Iterator[Tuple[Any, pa.Table]]: + ex_iterable = ( + self.ex_iterable if self.ex_iterable.iter_arrow else PythonToArrowExamplesIterable(self.ex_iterable) + ) + iterator = _batch_arrow_tables(ex_iterable.iter_arrow(), batch_size=self.batch_size if self.batched else 1) + current_idx = 0 + for key, pa_table in iterator: + # first build the batch + function_args = [pa_table] if self.input_columns is None else [pa_table[col] for col in self.input_columns] + if self.with_indices: + if self.batched: + function_args.append([current_idx + i for i in range(len(pa_table))]) + else: +
function_args.append(current_idx) + # then apply the transform + transformed_table = self.function(*function_args, **self.fn_kwargs) + if not isinstance(transformed_table, pa.Table): + raise TypeError( + f"Provided `function` which is applied to pyarrow tables returns a variable of type {type(transformed_table)}. Make sure provided `function` returns a pyarrow table to update the dataset." + ) + # merge results + merged_output = dict(zip(pa_table.column_names, pa_table.itercolumns())) + merged_output.update(dict(zip(transformed_table.column_names, transformed_table.itercolumns()))) + # then remove the unwanted columns + if self.remove_columns: + for column in self.remove_columns: + if column in merged_output: + del merged_output[column] + # return output + names, arrays = zip(*merged_output.items()) + output_table = pa.Table.from_arrays(arrays=arrays, names=names) + yield key, output_table + current_idx += len(pa_table) + def shuffle_data_sources(self, generator: np.random.Generator) -> "MappedExamplesIterable": """Shuffle the wrapped examples iterable.""" return MappedExamplesIterable( @@ -535,15 +746,26 @@ def __init__( input_columns: Optional[List[str]] = None, batched: bool = False, batch_size: Optional[int] = 1000, + format_type: Optional[str] = None, ): + super().__init__() self.ex_iterable = ex_iterable self.function = function self.batched = batched self.batch_size = batch_size self.with_indices = with_indices self.input_columns = input_columns + self.format_type = get_format_type_from_alias(format_type) + if format_type == "arrow": + self.iter_arrow = self._iter_arrow def __iter__(self): + if self.format_type == "arrow": + yield from ArrowExamplesIterable(self._iter_arrow, {}) + else: + yield from self._iter() + + def _iter(self): iterator = iter(self.ex_iterable) current_idx = 0 if self.batched: @@ -580,6 +802,26 @@ def __iter__(self): yield key, example current_idx += 1 + def _iter_arrow(self): + ex_iterable = ( + self.ex_iterable if self.ex_iterable.iter_arrow else PythonToArrowExamplesIterable(self.ex_iterable) + ) + iterator = _batch_arrow_tables(ex_iterable.iter_arrow(), batch_size=self.batch_size if self.batched else 1) + current_idx = 0 + for key, pa_table in iterator: + # first build the batch + function_args = [pa_table] if self.input_columns is None else [pa_table[col] for col in self.input_columns] + if self.with_indices: + if self.batched: + function_args.append([current_idx + i for i in range(len(pa_table))]) + else: + function_args.append(current_idx) + # then apply the transform + mask = self.function(*function_args) + # yield one example at a time from the batch + yield key, pa_table.filter(mask) + current_idx += len(pa_table) + def shuffle_data_sources(self, seed: Optional[int]) -> "FilteredExamplesIterable": """Shuffle the wrapped examples iterable.""" return FilteredExamplesIterable( @@ -609,9 +851,11 @@ def n_shards(self) -> int: class BufferShuffledExamplesIterable(_BaseExamplesIterable): def __init__(self, ex_iterable: _BaseExamplesIterable, buffer_size: int, generator: np.random.Generator): + super().__init__() self.ex_iterable = ex_iterable self.buffer_size = buffer_size self.generator = generator + # TODO(QL): implement iter_arrow @staticmethod def _iter_random_indices(rng: np.random.Generator, buffer_size: int, random_batch_size=1000) -> Iterator[int]: @@ -654,8 +898,10 @@ def n_shards(self) -> int: class SkipExamplesIterable(_BaseExamplesIterable): def __init__(self, ex_iterable: _BaseExamplesIterable, n: int): + super().__init__() self.ex_iterable =
ex_iterable self.n = n + # TODO(QL): implement iter_arrow def __iter__(self): yield from islice(self.ex_iterable, self.n, None) @@ -671,8 +917,10 @@ def n_shards(self) -> int: class TakeExamplesIterable(_BaseExamplesIterable): def __init__(self, ex_iterable: _BaseExamplesIterable, n: int): + super().__init__() self.ex_iterable = ex_iterable self.n = n + # TODO(QL): implement iter_arrow def __iter__(self): yield from islice(self.ex_iterable, self.n) @@ -724,9 +972,12 @@ def __init__( features: Features, token_per_repo_id: Dict[str, Union[str, bool, None]], ): + super().__init__() self.ex_iterable = ex_iterable self.features = features self.token_per_repo_id = token_per_repo_id + if self.ex_iterable.iter_arrow is not None: + self.iter_arrow = self._iter_arrow def __iter__(self): # Then for each example, `TypedExamplesIterable` automatically fills missing columns with None. @@ -736,6 +987,19 @@ def __iter__(self): example, self.features, token_per_repo_id=self.token_per_repo_id ) + def _iter_arrow(self) -> Iterator[Tuple[Any, pa.Table]]: + schema = self.features.arrow_schema + for key, pa_table in self.ex_iterable.iter_arrow(): + columns = set(pa_table.column_names) + # add missing columns + for column_name in self.features: + if column_name not in columns: + col = pa.NullArray.from_buffers(pa.null(), len(pa_table), [None]) + pa_table = pa_table.append_column(column_name, col) + if pa_table.schema != schema: + pa_table = cast_table_to_features(pa_table, self.features) + yield key, pa_table + def shuffle_data_sources(self, generator: np.random.Generator) -> "TypedExamplesIterable": """Shuffle the wrapped examples iterable.""" return TypedExamplesIterable( @@ -930,6 +1194,10 @@ def _prepare_ex_iterable_for_iteration(self) -> _BaseExamplesIterable: return ex_iterable def __iter__(self): + if self._format_type == "arrow": + yield from self.iter(batch_size=1) + return + ex_iterable = self._prepare_ex_iterable_for_iteration() if "torch" in sys.modules: @@ -959,7 +1227,24 @@ def iter(self, batch_size: int, drop_last_batch: bool = False): drop_last_batch (:obj:`bool`, default `False`): Whether a last batch smaller than the batch_size should be dropped """ - iterator = iter(self._prepare_ex_iterable_for_iteration()) + ex_iterable = self._prepare_ex_iterable_for_iteration() + if self._format_type == "arrow": + ex_iterable = ex_iterable if ex_iterable.iter_arrow else PythonToArrowExamplesIterable(ex_iterable) + for key, pa_table in _batch_arrow_tables( + ex_iterable.iter_arrow(), batch_size=batch_size, drop_last_batch=drop_last_batch + ): + if self.features: + columns = set(pa_table.column_names) + # add missing columns + for column_name in self.features: + if column_name not in columns: + col = pa.NullArray.from_buffers(pa.null(), len(pa_table), [None]) + pa_table = pa_table.append_column(column_name, col) + yield cast_table_to_features(pa_table, self.features) + else: + yield pa_table + return + iterator = iter(ex_iterable) for key, example in iterator: # If batched, first build the batch examples = [example] + [example for key, example in islice(iterator, batch_size - 1)] @@ -1154,6 +1439,7 @@ def map( drop_last_batch=drop_last_batch, remove_columns=remove_columns, fn_kwargs=fn_kwargs, + format_type=self._format_type, ) info = self.info.copy() info.features = features @@ -1229,6 +1515,7 @@ def filter( input_columns=input_columns, batched=batched, batch_size=batch_size, + format_type=self._format_type, ) return IterableDataset( ex_iterable=ex_iterable, From a905a19ba3e0bb3aebe8fc6dc6747f6908e9c95e Mon
Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Tue, 2 May 2023 18:49:19 +0200 Subject: [PATCH 02/10] some tests --- tests/test_iterable_dataset.py | 99 ++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 50d1e35de62..9915163a173 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd +import pyarrow as pa import pytest from datasets import load_dataset @@ -11,6 +12,7 @@ from datasets.formatting import get_format_type_from_alias from datasets.info import DatasetInfo from datasets.iterable_dataset import ( + ArrowExamplesIterable, BufferShuffledExamplesIterable, CyclingMultiSourcesExamplesIterable, ExamplesIterable, @@ -18,11 +20,13 @@ HorizontallyConcatenatedMultiSourcesExamplesIterable, IterableDataset, MappedExamplesIterable, + PythonToArrowExamplesIterable, RandomlyCyclingMultiSourcesExamplesIterable, ShufflingConfig, SkipExamplesIterable, TakeExamplesIterable, VerticallyConcatenatedMultiSourcesExamplesIterable, + _batch_arrow_tables, _batch_to_examples, _examples_to_batch, ) @@ -31,6 +35,7 @@ DEFAULT_N_EXAMPLES = 20 +DEFAULT_BATCH_SIZE = 4 DEFAULT_FILEPATH = "file.txt" SAMPLE_DATASET_IDENTIFIER = "lhoestq/test" # has dataset script @@ -47,6 +52,25 @@ def generate_examples_fn(**kwargs): yield f"{filepath}_{i}", {"id": i, **kwargs} +def generate_tables_fn(**kwargs): + kwargs = kwargs.copy() + n = kwargs.pop("n", DEFAULT_N_EXAMPLES) + batch_size = kwargs.pop("batch_size", DEFAULT_BATCH_SIZE) + filepaths = kwargs.pop("filepaths", None) + for filepath in filepaths or [DEFAULT_FILEPATH]: + buffer = [] + batch_idx = 0 + if filepaths is not None: + kwargs["filepath"] = filepath + for i in range(n): + buffer.append({"id": i, **kwargs}) + if len(buffer) == batch_size: + yield f"{filepath}_{batch_idx}", pa.Table.from_pylist(buffer) + buffer = [] + batch_idx += 1 + yield f"{filepath}_{batch_idx}", pa.Table.from_pylist(buffer) + + @pytest.fixture def dataset(): ex_iterable = ExamplesIterable(generate_examples_fn, {}) return IterableDataset(ex_iterable, info=DatasetInfo(description="dummy"), split="train") @@ -62,6 +86,39 @@ def dataset_with_several_columns(): return IterableDataset(ex_iterable, info=DatasetInfo(description="dummy"), split="train") +################################ +# +# Utilities tests +# +################################ + + +@pytest.mark.parametrize( + "tables", + [ + [pa.table({"foo": range(10)})], + [pa.table({"foo": range(0, 5)}), pa.table({"foo": range(5, 10)})], + [pa.table({"foo": [i]}) for i in range(10)], + ], +) +@pytest.mark.parametrize("batch_size", [1, 2, 3, 9, 10, 11, 20]) +@pytest.mark.parametrize("drop_last_batch", [False, True]) +def test_batch_arrow_tables(tables, batch_size, drop_last_batch): + full_table = pa.concat_tables(tables) + num_rows = len(full_table) if not drop_last_batch else len(full_table) // batch_size * batch_size + num_batches = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size + subtables = list(_batch_arrow_tables(tables, batch_size=batch_size, drop_last_batch=drop_last_batch)) + assert len(subtables) == num_batches + if drop_last_batch: + assert all(len(subtable) == batch_size for subtable in subtables) + else: + assert all(len(subtable) == batch_size for subtable in subtables[:-1]) + assert len(subtables[-1]) <= batch_size + if num_rows > 0: + reloaded = pa.concat_tables(subtables) + assert full_table.slice(0, num_rows).to_pydict() == reloaded.to_pydict() + + ################################ # # _BaseExampleIterable tests # ################################ @@ -74,6 +131,7
@@ def test_examples_iterable(): expected = list(generate_examples_fn()) assert next(iter(ex_iterable)) == expected[0] assert list(ex_iterable) == expected + assert ex_iterable.iter_arrow is None def test_examples_iterable_with_kwargs(): @@ -110,6 +168,47 @@ def gen(filepaths, all_metadata): assert filepaths_ids == metadata_ids, "entangled lists of shards/metadata should be shuffled the same way" +def test_arrow_examples_iterable(): + ex_iterable = ArrowExamplesIterable(generate_tables_fn, {}) + expected = sum([pa_table.to_pylist() for _, pa_table in generate_tables_fn()], []) + assert next(iter(ex_iterable))[1] == expected[0] + assert [example for _, example in ex_iterable] == expected + expected = list(generate_tables_fn()) + assert list(ex_iterable.iter_arrow()) == expected + + +def test_arrow_examples_iterable_with_kwargs(): + ex_iterable = ArrowExamplesIterable(generate_tables_fn, {"filepaths": ["0.txt", "1.txt"], "split": "train"}) + expected = sum( + [pa_table.to_pylist() for _, pa_table in generate_tables_fn(filepaths=["0.txt", "1.txt"], split="train")], [] + ) + assert [example for _, example in ex_iterable] == expected + assert all("split" in ex for _, ex in ex_iterable) + assert sorted({ex["filepath"] for _, ex in ex_iterable}) == ["0.txt", "1.txt"] + expected = list(generate_tables_fn(filepaths=["0.txt", "1.txt"], split="train")) + assert list(ex_iterable.iter_arrow()) == expected + + +def test_arrow_examples_iterable_shuffle_data_sources(): + ex_iterable = ArrowExamplesIterable(generate_tables_fn, {"filepaths": ["0.txt", "1.txt"]}) + ex_iterable = ex_iterable.shuffle_data_sources(np.random.default_rng(40)) + expected = sum( + [pa_table.to_pylist() for _, pa_table in generate_tables_fn(filepaths=["1.txt", "0.txt"])], [] + ) # shuffle the filepaths + assert [example for _, example in ex_iterable] == expected + expected = list(generate_tables_fn(filepaths=["1.txt", "0.txt"])) + assert list(ex_iterable.iter_arrow()) == expected + + +def test_python_to_arrow_examples_iterable(): + python_ex_iterable = ExamplesIterable(generate_examples_fn, {}) + arrow_ex_iterable = PythonToArrowExamplesIterable(python_ex_iterable) + assert list(python_ex_iterable) == list(arrow_ex_iterable) + tables = [pa_table for _, pa_table in arrow_ex_iterable.iter_arrow()] + assert 1 < len(tables) < len(list(python_ex_iterable)) + assert pa.Table.from_pylist([example for _, example in python_ex_iterable]) == pa.concat_tables(tables) + + @pytest.mark.parametrize("seed", [42, 1337, 101010, 123456]) def test_buffer_shuffled_examples_iterable(seed): n, buffer_size = 100, 30 From 6c868e10bd7c15de0a4985517a583d3b7684fda6 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 4 May 2023 19:19:50 +0200 Subject: [PATCH 03/10] fix filter --- src/datasets/iterable_dataset.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 9a46e7fa84b..27f80f35a3c 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -431,12 +431,17 @@ class VerticallyConcatenatedMultiSourcesExamplesIterable(_BaseExamplesIterable): def __init__(self, ex_iterables: List[_BaseExamplesIterable]): super().__init__() self.ex_iterables = ex_iterables - # TODO(QL): implement iter_arrow + if all(ex_iterable.iter_arrow is not None for ex_iterable in ex_iterables): + self.iter_arrow = self._iter_arrow def __iter__(self): for ex_iterable in self.ex_iterables: yield from ex_iterable + def _iter_arrow(self): + for ex_iterable 
in self.ex_iterables: + yield from ex_iterable.iter_arrow() + def shuffle_data_sources( self, generator: np.random.Generator ) -> "VerticallyConcatenatedMultiSourcesExamplesIterable": @@ -818,8 +823,11 @@ def _iter_arrow(self): function_args.append(current_idx) # then apply the transform mask = self.function(*function_args) - # yield one example at a time from the batch - yield key, pa_table.filter(mask) + # yield the filtered table + if self.batched: + yield key, pa_table.filter(mask) + elif mask.as_py() if isinstance(mask, pa.BooleanScalar) else mask: + yield key, pa_table current_idx += len(pa_table) def shuffle_data_sources(self, seed: Optional[int]) -> "FilteredExamplesIterable": From f8417a41547ce0c939bd342398be621f5ce3e340 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 4 May 2023 19:19:56 +0200 Subject: [PATCH 04/10] add test --- tests/test_iterable_dataset.py | 64 ++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 9915163a173..cc17807443a 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -22,10 +22,16 @@ MappedExamplesIterable, PythonToArrowExamplesIterable, RandomlyCyclingMultiSourcesExamplesIterable, + SelectColumnsIterable, + ShuffledDataSourcesArrowExamplesIterable, + ShuffledDataSourcesExamplesIterable, ShufflingConfig, SkipExamplesIterable, + StepExamplesIterable, TakeExamplesIterable, + TypedExamplesIterable, VerticallyConcatenatedMultiSourcesExamplesIterable, + _BaseExamplesIterable, _batch_arrow_tables, _batch_to_examples, _examples_to_batch, @@ -702,6 +708,64 @@ def test_horizontally_concatenated_examples_iterable(): ), "horizontally concatenated examples makes the shards order fixed" +@pytest.mark.parametrize( + "ex_iterable", + [ + ExamplesIterable(generate_examples_fn, {}), + ShuffledDataSourcesExamplesIterable(generate_examples_fn, {}, np.random.default_rng(42)), + SelectColumnsIterable(ExamplesIterable(generate_examples_fn, {}), ["id"]), + StepExamplesIterable(ExamplesIterable(generate_examples_fn, {}), 2, 0), + CyclingMultiSourcesExamplesIterable([ExamplesIterable(generate_examples_fn, {})]), + VerticallyConcatenatedMultiSourcesExamplesIterable([ExamplesIterable(generate_examples_fn, {})]), + HorizontallyConcatenatedMultiSourcesExamplesIterable([ExamplesIterable(generate_examples_fn, {})]), + RandomlyCyclingMultiSourcesExamplesIterable( + [ExamplesIterable(generate_examples_fn, {})], np.random.default_rng(42) + ), + MappedExamplesIterable(ExamplesIterable(generate_examples_fn, {}), lambda x: x), + MappedExamplesIterable(ArrowExamplesIterable(generate_tables_fn, {}), lambda x: x), + FilteredExamplesIterable(ExamplesIterable(generate_examples_fn, {}), lambda x: True), + FilteredExamplesIterable(ArrowExamplesIterable(generate_tables_fn, {}), lambda x: True), + BufferShuffledExamplesIterable(ExamplesIterable(generate_examples_fn, {}), 10, np.random.default_rng(42)), + SkipExamplesIterable(ExamplesIterable(generate_examples_fn, {}), 10), + TakeExamplesIterable(ExamplesIterable(generate_examples_fn, {}), 10), + TypedExamplesIterable( + ExamplesIterable(generate_examples_fn, {}), Features({"id": Value("int32")}), token_per_repo_id={} + ), + ], +) +def test_no_iter_arrow(ex_iterable: _BaseExamplesIterable): + assert ex_iterable.iter_arrow is None + + +@pytest.mark.parametrize( + "ex_iterable", + [ + ArrowExamplesIterable(generate_tables_fn, {}), + ShuffledDataSourcesArrowExamplesIterable(generate_tables_fn, {}, 
np.random.default_rng(42)), + SelectColumnsIterable(ArrowExamplesIterable(generate_tables_fn, {}), ["id"]), + # StepExamplesIterable(ArrowExamplesIterable(generate_tables_fn, {}), 2, 0), # not implemented + # CyclingMultiSourcesExamplesIterable([ArrowExamplesIterable(generate_tables_fn, {})]), # not implemented + VerticallyConcatenatedMultiSourcesExamplesIterable([ArrowExamplesIterable(generate_tables_fn, {})]), + # HorizontallyConcatenatedMultiSourcesExamplesIterable([ArrowExamplesIterable(generate_tables_fn, {})]), # not implemented + # RandomlyCyclingMultiSourcesExamplesIterable([ArrowExamplesIterable(generate_tables_fn, {})], np.random.default_rng(42)), # not implemented + MappedExamplesIterable(ExamplesIterable(generate_examples_fn, {}), lambda t: t, format_type="arrow"), + MappedExamplesIterable(ArrowExamplesIterable(generate_tables_fn, {}), lambda t: t, format_type="arrow"), + FilteredExamplesIterable(ExamplesIterable(generate_examples_fn, {}), lambda t: True, format_type="arrow"), + FilteredExamplesIterable(ArrowExamplesIterable(generate_tables_fn, {}), lambda t: True, format_type="arrow"), + # BufferShuffledExamplesIterable(ArrowExamplesIterable(generate_tables_fn, {}), 10, np.random.default_rng(42)), # not implemented + # SkipExamplesIterable(ArrowExamplesIterable(generate_tables_fn, {}), 10), # not implemented + # TakeExamplesIterable(ArrowExamplesIterable(generate_tables_fn, {}), 10), # not implemented + TypedExamplesIterable( + ArrowExamplesIterable(generate_tables_fn, {}), Features({"id": Value("int32")}), token_per_repo_id={} + ), + ], +) +def test_iter_arrow(ex_iterable: _BaseExamplesIterable): + assert ex_iterable.iter_arrow is not None + key, pa_table = next(ex_iterable.iter_arrow()) + assert isinstance(pa_table, pa.Table) + + ############################ # # IterableDataset tests From 95457f2aebdc35a6e460d0afb1212f064ea56c79 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Fri, 5 May 2023 11:30:37 +0200 Subject: [PATCH 05/10] fix test --- tests/test_iterable_dataset.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index cc17807443a..4bd529c1977 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -113,15 +113,15 @@ def test_batch_arrow_tables(tables, batch_size, drop_last_batch): full_table = pa.concat_tables(tables) num_rows = len(full_table) if not drop_last_batch else len(full_table) // batch_size * batch_size num_batches = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size - subtables = list(_batch_arrow_tables(tables, batch_size=batch_size, drop_last_batch=drop_last_batch)) + subtables = list(_batch_arrow_tables([(i, table) for i, table in enumerate(tables)], batch_size=batch_size, drop_last_batch=drop_last_batch)) assert len(subtables) == num_batches if drop_last_batch: - assert all(len(subtable) == batch_size for subtable in subtables) + assert all(len(subtable) == batch_size for _, subtable in subtables) else: - assert all(len(subtable) == batch_size for subtable in subtables[:-1]) - assert len(subtables[-1]) <= batch_size + assert all(len(subtable) == batch_size for _, subtable in subtables[:-1]) + assert len(subtables[-1][1]) <= batch_size if num_rows > 0: - reloaded = pa.concat_tables(subtables) + reloaded = pa.concat_tables([subtable for _, subtable in subtables]) assert full_table.slice(0, num_rows).to_pydict() == reloaded.to_pydict() From 8f019dffffb214b44b30dd9ac56fdea12259e148 Mon Sep 17 00:00:00 
2001 From: Quentin Lhoest Date: Tue, 9 May 2023 20:14:24 +0200 Subject: [PATCH 06/10] tests and fixes --- src/datasets/iterable_dataset.py | 153 ++++++++-------- tests/test_iterable_dataset.py | 291 +++++++++++++++++++++++++++++-- 2 files changed, 365 insertions(+), 79 deletions(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 27f80f35a3c..e746f0ad305 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -5,7 +5,7 @@ from copy import deepcopy from dataclasses import dataclass from itertools import cycle, islice -from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union import numpy as np import pyarrow as pa @@ -25,6 +25,8 @@ logger = get_logger(__name__) +Key = Union[int, str] + def _infer_features_from_batch(batch: Dict[str, list], try_features: Optional[Features] = None) -> Features: pa_table = pa.Table.from_pydict(batch) @@ -81,23 +83,57 @@ def hasnext(self): return self._hasnext -def _batch_arrow_tables( - iterator: Iterator[Tuple[Any, pa.Table]], +def _convert_to_arrow( + iterable: Iterable[Tuple[Key, dict]], batch_size: int, drop_last_batch=False, -) -> Iterator[Tuple[Any, pa.Table]]: +) -> Iterator[Tuple[Key, pa.Table]]: + """Iterate over sub-tables of size `batch_size`. + + Args: + iterable (`Iterable[Tuple[Key, dict]]`): + An examples iterable containing tuples (example_key, example) of type (int/str, dict) + batch_size (`Optional[int]`): + Size of each sub-table to yield. If None or <= 0, yields the full table. + drop_last_batch (`bool`, defaults to `False`): + Drop the last batch if it is smaller than `batch_size`. + """ + if batch_size is None or batch_size <= 0: + yield "all", pa.Table.from_pylist([example for _, example in iterable]) + return + iterator = iter(iterable) + for key, example in iterator: + iterator_batch = islice(iterator, batch_size - 1) + key_examples_list = [(key, example)] + [(key, example) for key, example in iterator_batch] + if len(key_examples_list) < batch_size and drop_last_batch: + return + keys, examples = zip(*key_examples_list) + new_key = "_".join(str(key) for key in keys) + yield new_key, pa.Table.from_pylist(examples) + + +def _batch_arrow_tables( + iterable: Iterable[Tuple[Key, pa.Table]], + batch_size: Optional[int], + drop_last_batch=False, +) -> Iterator[Tuple[Key, pa.Table]]: """Iterate over sub-tables of size `batch_size`. Args: - batch_size (`int`): - Size of each sub-table to yield. + iterable (`Iterable[Tuple[Key, pa.Table]]`): + A tables iterable containing tuples (table_key, table) of type (int/str, pa.Table) + batch_size (`Optional[int]`): + Size of each sub-table to yield. If None or <= 0, yields the full table. drop_last_batch (`bool`, defaults to `False`): Drop the last batch if it is smaller than `batch_size`. 
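+
+     Example (illustrative; keys of pooled examples are joined with "_", and `batch_size=None` packs everything into a single table keyed "all"):
+
+         >>> examples = [(i, {"x": i}) for i in range(5)]
+         >>> [(key, len(t)) for key, t in _convert_to_arrow(examples, batch_size=2)]
+         [('0_1', 2), ('2_3', 2), ('4', 1)]
+         >>> [(key, len(t)) for key, t in _convert_to_arrow(examples, batch_size=None)]
+         [('all', 5)]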
""" + if batch_size is None or batch_size <= 0: + yield "all", pa.concat_tables([pa_table for _, pa_table in iterable]) + return keys_buffer = [] chunks_buffer = [] chunks_buffer_size = 0 - for key, pa_table in iterator: + for key, pa_table in iterable: for chunk in pa_table.to_reader(max_chunksize=batch_size): if len(chunk) == 0: continue @@ -132,9 +168,9 @@ class _BaseExamplesIterable: """Base class for the examples iterable used by an IterableDataset""" def __init__(self) -> None: - self.iter_arrow: Optional[Callable[[], Iterator[Tuple[Any, pa.Table]]]] = None + self.iter_arrow: Optional[Callable[[], Iterator[Tuple[Key, pa.Table]]]] = None - def __iter__(self): + def __iter__(self) -> Iterator[Tuple[Key, dict]]: """An examples iterable should yield tuples (example_key, example) of type (int/str, dict)""" raise NotImplementedError(f"{type(self)} doesn't implement __iter__ yet") @@ -155,7 +191,7 @@ def n_shards(self) -> int: class ExamplesIterable(_BaseExamplesIterable): - def __init__(self, generate_examples_fn: Callable, kwargs: dict): + def __init__(self, generate_examples_fn: Callable[..., Tuple[Key, dict]], kwargs: dict): super().__init__() self.generate_examples_fn = generate_examples_fn self.kwargs = kwargs @@ -178,7 +214,9 @@ def n_shards(self) -> int: class ShuffledDataSourcesExamplesIterable(ExamplesIterable): - def __init__(self, generate_examples_fn: Callable, kwargs: dict, generator: np.random.Generator): + def __init__( + self, generate_examples_fn: Callable[..., Tuple[Key, dict]], kwargs: dict, generator: np.random.Generator + ): super().__init__(generate_examples_fn, kwargs) self.generator = deepcopy(generator) @@ -198,7 +236,7 @@ def shard_data_sources(self, shard_indices: List[int]) -> "ExamplesIterable": class ArrowExamplesIterable(_BaseExamplesIterable): - def __init__(self, generate_tables_fn: Callable, kwargs: dict): + def __init__(self, generate_tables_fn: Callable[..., Tuple[Key, pa.Table]], kwargs: dict): super().__init__() self.generate_tables_fn = generate_tables_fn self.kwargs = kwargs @@ -234,7 +272,7 @@ def n_shards(self) -> int: class ShuffledDataSourcesArrowExamplesIterable(ArrowExamplesIterable): def __init__( self, - generate_tables_fn: Callable, + generate_tables_fn: Callable[..., Tuple[Key, pa.Table]], kwargs: dict, generator: np.random.Generator, ): @@ -266,36 +304,6 @@ def shard_data_sources(self, shard_indices: List[int]) -> "ExamplesIterable": ) -class PythonToArrowExamplesIterable(_BaseExamplesIterable): - def __init__(self, ex_iterable: _BaseExamplesIterable): - super().__init__() - self.ex_iterable = ex_iterable - self.iter_arrow = self._iter_arrow - - def __iter__(self): - yield from self.ex_iterable - - def _iter_arrow(self) -> Iterator[Tuple[Any, pa.Table]]: - iterator = iter(self.ex_iterable) - batch_size = config.ARROW_READER_BATCH_SIZE_IN_DATASET_ITER - for key, example in iterator: - iterator_batch = islice(iterator, batch_size - 1) - key_examples_list = [(key, example)] + [(key, example) for key, example in iterator_batch] - keys, examples = zip(*key_examples_list) - new_key = "_".join(str(key) for key in keys) - yield new_key, pa.Table.from_pylist(examples) - - def shuffle_data_sources(self, generator: np.random.Generator) -> "PythonToArrowExamplesIterable": - return PythonToArrowExamplesIterable(self.ex_iterable.shuffle_data_sources(generator)) - - def shard_data_sources(self, shard_indices: List[int]) -> "PythonToArrowExamplesIterable": - return PythonToArrowExamplesIterable(self.ex_iterable.shard_data_sources(shard_indices)) - - 
@property - def n_shards(self) -> int: - return self.ex_iterable.n_shards - - class SelectColumnsIterable(_BaseExamplesIterable): def __init__(self, ex_iterable: _BaseExamplesIterable, column_names: List[str]): super().__init__() @@ -308,7 +316,7 @@ def __iter__(self): for idx, row in self.ex_iterable: yield idx, {c: row[c] for c in self.column_names} - def _iter_arrow(self) -> Iterator[Tuple[Any, pa.Table]]: + def _iter_arrow(self) -> Iterator[Tuple[Key, pa.Table]]: for idx, pa_table in self.ex_iterable.iter_arrow(): yield idx, pa_table.select(self.column_names) @@ -677,11 +685,19 @@ def _iter(self): yield key, transformed_example current_idx += 1 - def _iter_arrow(self) -> Iterator[Tuple[Any, pa.Table]]: - ex_iterable = ( - self.ex_iterable if self.ex_iterable.iter_arrow else PythonToArrowExamplesIterable(self.ex_iterable) - ) - iterator = _batch_arrow_tables(ex_iterable.iter_arrow(), batch_size=self.batch_size if self.batched else 1) + def _iter_arrow(self) -> Iterator[Tuple[Key, pa.Table]]: + if self.ex_iterable.iter_arrow: + iterator = _batch_arrow_tables( + self.ex_iterable.iter_arrow(), + batch_size=self.batch_size if self.batched else 1, + drop_last_batch=self.drop_last_batch, + ) + else: + iterator = _convert_to_arrow( + self.ex_iterable, + batch_size=self.batch_size if self.batched else 1, + drop_last_batch=self.drop_last_batch, + ) current_idx = 0 for key, pa_table in iterator: # first build the batch @@ -692,22 +708,18 @@ def _iter_arrow(self) -> Iterator[Tuple[Any, pa.Table]]: else: function_args.append(current_idx) # then apply the transform - transformed_table = self.function(*function_args, **self.fn_kwargs) - if not isinstance(transformed_table, pa.Table): + output_table = self.function(*function_args, **self.fn_kwargs) + if not isinstance(output_table, pa.Table): raise TypeError( - f"Provided `function` which is applied to pyarrow tables returns a variable of type {type(transformed_table)}. Make sure provided `function` returns a pyarrow table to update the dataset." + f"Provided `function` which is applied to pyarrow tables returns a variable of type {type(output_table)}. Make sure provided `function` returns a pyarrow table to update the dataset."
) - # merge results - merged_output = dict(zip(pa_table.column_names, pa_table.itercolumns())) - merged_output.update(dict(zip(transformed_table.column_names, transformed_table.itercolumns()))) + # we don't need to merge results for consistency with Dataset.map which merges iff both input and output are dicts # then remove the unwanted columns if self.remove_columns: for column in self.remove_columns: - if column in merged_output: - del merged_output[column] + if column in output_table.column_names: + output_table = output_table.remove_column(output_table.column_names.index(column)) # return output - names, arrays = zip(*merged_output.items()) - output_table = pa.Table.from_arrays(arrays=arrays, names=names) yield key, output_table current_idx += len(pa_table) @@ -808,10 +820,12 @@ def _iter(self): current_idx += 1 def _iter_arrow(self): - ex_iterable = ( - self.ex_iterable if self.ex_iterable.iter_arrow else PythonToArrowExamplesIterable(self.ex_iterable) - ) - iterator = _batch_arrow_tables(ex_iterable.iter_arrow(), batch_size=self.batch_size if self.batched else 1) + if self.ex_iterable.iter_arrow: + iterator = _batch_arrow_tables( + self.ex_iterable.iter_arrow(), batch_size=self.batch_size if self.batched else 1 + ) + else: + iterator = _convert_to_arrow(self.ex_iterable, batch_size=self.batch_size if self.batched else 1) current_idx = 0 for key, pa_table in iterator: # first build the batch @@ -995,7 +1009,7 @@ def __iter__(self): example, self.features, token_per_repo_id=self.token_per_repo_id ) - def _iter_arrow(self) -> Iterator[Tuple[Any, pa.Table]]: + def _iter_arrow(self) -> Iterator[Tuple[Key, pa.Table]]: schema = self.features.arrow_schema for key, pa_table in self.ex_iterable.iter_arrow(): columns = set(pa_table.column_names) @@ -1237,10 +1251,13 @@ def iter(self, batch_size: int, drop_last_batch: bool = False): """ ex_iterable = self._prepare_ex_iterable_for_iteration() if self._format_type == "arrow": - ex_iterable = ex_iterable if ex_iterable.iter_arrow else PythonToArrowExamplesIterable(ex_iterable) - for key, pa_table in _batch_arrow_tables( - ex_iterable.iter_arrow(), batch_size=batch_size, drop_last_batch=drop_last_batch - ): + if ex_iterable.iter_arrow: + iterator = _batch_arrow_tables( + ex_iterable.iter_arrow(), batch_size=batch_size, drop_last_batch=drop_last_batch + ) + else: + iterator = _convert_to_arrow(ex_iterable, batch_size=batch_size, drop_last_batch=drop_last_batch) + for key, pa_table in iterator: if self.features: columns = set(pa_table.column_names) # add missing columns diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 4bd529c1977..5bbec98af14 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd import pyarrow as pa +import pyarrow.compute as pc import pytest from datasets import load_dataset @@ -20,7 +21,6 @@ HorizontallyConcatenatedMultiSourcesExamplesIterable, IterableDataset, MappedExamplesIterable, - PythonToArrowExamplesIterable, RandomlyCyclingMultiSourcesExamplesIterable, SelectColumnsIterable, ShuffledDataSourcesArrowExamplesIterable, @@ -34,6 +35,7 @@ _BaseExamplesIterable, _batch_arrow_tables, _batch_to_examples, + _convert_to_arrow, _examples_to_batch, ) @@ -99,6 +100,31 @@ def dataset_with_several_columns(): ################################ +@pytest.mark.parametrize("batch_size", [1, 2, 3, 9, 10, 11, 20]) +@pytest.mark.parametrize("drop_last_batch", [False, True]) +def test_convert_to_arrow(batch_size, drop_last_batch): +
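+    # pack the (key, example) pairs into pyarrow tables and check that batch sizes and content round-trip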
examples = [{"foo": i} for i in range(10)] + full_table = pa.Table.from_pylist(examples) + num_rows = len(full_table) if not drop_last_batch else len(full_table) // batch_size * batch_size + num_batches = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size + subtables = list( + _convert_to_arrow( + [(i, example) for i, example in enumerate(examples)], + batch_size=batch_size, + drop_last_batch=drop_last_batch, + ) + ) + assert len(subtables) == num_batches + if drop_last_batch: + assert all(len(subtable) == batch_size for _, subtable in subtables) + else: + assert all(len(subtable) == batch_size for _, subtable in subtables[:-1]) + assert len(subtables[-1][1]) <= batch_size + if num_rows > 0: + reloaded = pa.concat_tables([subtable for _, subtable in subtables]) + assert full_table.slice(0, num_rows).to_pydict() == reloaded.to_pydict() + + @pytest.mark.parametrize( "tables", [ @@ -113,7 +139,11 @@ def test_batch_arrow_tables(tables, batch_size, drop_last_batch): full_table = pa.concat_tables(tables) num_rows = len(full_table) if not drop_last_batch else len(full_table) // batch_size * batch_size num_batches = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size - subtables = list(_batch_arrow_tables([(i, table) for i, table in enumerate(tables)], batch_size=batch_size, drop_last_batch=drop_last_batch)) + subtables = list( + _batch_arrow_tables( + [(i, table) for i, table in enumerate(tables)], batch_size=batch_size, drop_last_batch=drop_last_batch + ) + ) assert len(subtables) == num_batches if drop_last_batch: assert all(len(subtable) == batch_size for _, subtable in subtables) @@ -206,15 +236,6 @@ def test_arrow_examples_iterable_shuffle_data_sources(): assert list(ex_iterable.iter_arrow()) == expected -def test_python_to_arrow_examples_iterable(): - python_ex_iterable = ExamplesIterable(generate_examples_fn, {}) - arrow_ex_iterable = PythonToArrowExamplesIterable(python_ex_iterable) - assert list(python_ex_iterable) == list(arrow_ex_iterable) - tables = [pa_table for _, pa_table in arrow_ex_iterable.iter_arrow()] - assert 1 < len(tables) < len(list(python_ex_iterable)) - assert pa.Table.from_pylist([example for _, example in python_ex_iterable]) == pa.concat_tables(tables) - - @pytest.mark.parametrize("seed", [42, 1337, 101010, 123456]) def test_buffer_shuffled_examples_iterable(seed): n, buffer_size = 100, 30 @@ -540,6 +561,254 @@ def test_mapped_examples_iterable_input_columns(n, func, batched, batch_size, in assert [x for _, x in ex_iterable] == expected +@pytest.mark.parametrize( + "n, func, batched, batch_size", + [ + (3, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), False, None), # just add 1 to the id + (3, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, 1), # same with bs=1 + (5, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, 10), # same with bs=10 + (25, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, 10), # same with bs=10 + (5, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, None), # same with bs=None + (5, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, -1), # same with bs<=0 + (3, lambda t: pa.concat_tables([t] * 2), True, 1), # make a duplicate of each example + ], +) +def test_mapped_examples_iterable_arrow_format(n, func, batched, batch_size): + base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n}) + ex_iterable = MappedExamplesIterable( + base_ex_iterable, func, batched=batched, batch_size=batch_size, format_type="arrow" + ) 
+ all_examples = [x for _, x in generate_examples_fn(n=n)] + if batched is False: + expected = [func(pa.Table.from_pylist([x])).to_pylist()[0] for x in all_examples] + else: + expected = [] + # If batch_size is None or <=0, we use the whole dataset as a single batch + if batch_size is None or batch_size <= 0: + batch_size = len(all_examples) + for batch_offset in range(0, len(all_examples), batch_size): + examples = all_examples[batch_offset : batch_offset + batch_size] + batch = pa.Table.from_pylist(examples) + expected.extend(func(batch).to_pylist()) + assert next(iter(ex_iterable))[1] == expected[0] + assert [x for _, x in ex_iterable] == expected + + +@pytest.mark.parametrize( + "n, func, batched, batch_size", + [ + (3, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), False, None), # just add 1 to the id + (3, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, 1), # same with bs=1 + (5, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, 10), # same with bs=10 + (25, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, 10), # same with bs=10 + (5, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, None), # same with bs=None + (5, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, -1), # same with bs<=0 + (3, lambda t: pa.concat_tables([t] * 2), True, 1), # make a duplicate of each example + ], +) +def test_mapped_examples_iterable_drop_last_batch_and_arrow_format(n, func, batched, batch_size): + base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n}) + ex_iterable = MappedExamplesIterable( + base_ex_iterable, func, batched=batched, batch_size=batch_size, drop_last_batch=True, format_type="arrow" + ) + all_examples = [x for _, x in generate_examples_fn(n=n)] + is_empty = False + if batched is False: + # `drop_last_batch` has no effect here + expected = [func(pa.Table.from_pylist([x])).to_pylist()[0] for x in all_examples] + else: + all_transformed_examples = [] + # If batch_size is None or <=0, we use the whole dataset as a single batch + if batch_size is None or batch_size <= 0: + batch_size = len(all_examples) + for batch_offset in range(0, len(all_examples), batch_size): + examples = all_examples[batch_offset : batch_offset + batch_size] + if len(examples) < batch_size: # ignore last batch + break + batch = pa.Table.from_pylist(examples) + out = func(batch) + all_transformed_examples.extend( + out.to_pylist() + ) # we don't merge with input since they're arrow tables and not dictionaries + all_examples = all_examples if n % batch_size == 0 else all_examples[: n // batch_size * batch_size] + if all_examples: + expected = all_transformed_examples + else: + is_empty = True + + if not is_empty: + assert next(iter(ex_iterable))[1] == expected[0] + assert [x for _, x in ex_iterable] == expected + else: + with pytest.raises(StopIteration): + next(iter(ex_iterable)) + + +@pytest.mark.parametrize( + "n, func, batched, batch_size", + [ + ( + 3, + lambda t, index: t.append_column("id+idx", pc.add(t["id"], index)), + False, + None, + ), # add the index to the id + ( + 25, + lambda t, indices: t.append_column("id+idx", pc.add(t["id"], indices)), + True, + 10, + ), # add the index to the id + (5, lambda t, indices: t.append_column("id+idx", pc.add(t["id"], indices)), True, None), # same with bs=None + (5, lambda t, indices: t.append_column("id+idx", pc.add(t["id"], indices)), True, -1), # same with bs<=0 + ], +) +def test_mapped_examples_iterable_with_indices_and_arrow_format(n, func, batched, batch_size): + base_ex_iterable = 
ExamplesIterable(generate_examples_fn, {"n": n}) + ex_iterable = MappedExamplesIterable( + base_ex_iterable, func, batched=batched, batch_size=batch_size, with_indices=True, format_type="arrow" + ) + all_examples = [x for _, x in generate_examples_fn(n=n)] + if batched is False: + expected = [func(pa.Table.from_pylist([x]), i).to_pylist()[0] for i, x in enumerate(all_examples)] + else: + expected = [] + # If batch_size is None or <=0, we use the whole dataset as a single batch + if batch_size is None or batch_size <= 0: + batch_size = len(all_examples) + for batch_offset in range(0, len(all_examples), batch_size): + examples = all_examples[batch_offset : batch_offset + batch_size] + batch = pa.Table.from_pylist(examples) + expected.extend(func(batch, list(range(batch_offset, batch_offset + len(batch)))).to_pylist()) + assert next(iter(ex_iterable))[1] == expected[0] + assert [x for _, x in ex_iterable] == expected + + +@pytest.mark.parametrize( + "n, func, batched, batch_size, remove_columns", + [ + ( + 3, + lambda t: t.append_column("id+1", pc.add(t["id"], 1)), + False, + None, + ["extra_column"], + ), # just add 1 to the id + (25, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, 10, ["extra_column"]), # same with bs=10 + ( + 50, + lambda t: pa.table({"foo": ["bar"] * np.random.default_rng(t["id"][0].as_py()).integers(0, 10)}), + True, + 8, + ["extra_column", "id"], + ), # make a duplicate of each example + (5, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, None, ["extra_column"]), # same with bs=None + (5, lambda t: t.append_column("id+1", pc.add(t["id"], 1)), True, -1, ["extra_column"]), # same with bs<=0 + ], +) +def test_mapped_examples_iterable_remove_columns_arrow_format(n, func, batched, batch_size, remove_columns): + base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n, "extra_column": "foo"}) + ex_iterable = MappedExamplesIterable( + base_ex_iterable, + func, + batched=batched, + batch_size=batch_size, + remove_columns=remove_columns, + format_type="arrow", + ) + all_examples = [x for _, x in generate_examples_fn(n=n)] + columns_to_remove = remove_columns if isinstance(remove_columns, list) else [remove_columns] + if batched is False: + expected = [ + {**{k: v for k, v in func(pa.Table.from_pylist([x])).to_pylist()[0].items() if k not in columns_to_remove}} + for x in all_examples + ] + else: + expected = [] + # If batch_size is None or <=0, we use the whole dataset as a single batch + if batch_size is None or batch_size <= 0: + batch_size = len(all_examples) + for batch_offset in range(0, len(all_examples), batch_size): + examples = all_examples[batch_offset : batch_offset + batch_size] + batch = pa.Table.from_pylist(examples) + expected.extend( + [{k: v for k, v in x.items() if k not in columns_to_remove} for x in func(batch).to_pylist()] + ) + assert next(iter(ex_iterable))[1] == expected[0] + assert [x for _, x in ex_iterable] == expected + + +@pytest.mark.parametrize( + "n, func, batched, batch_size, fn_kwargs", + [ + (3, lambda t, y=0: t.append_column("id+idx", pc.add(t["id"], y)), False, None, None), + (3, lambda t, y=0: t.append_column("id+idx", pc.add(t["id"], y)), False, None, {"y": 3}), + (25, lambda t, y=0: t.append_column("id+idx", pc.add(t["id"], y)), True, 10, {"y": 3}), + (5, lambda t, y=0: t.append_column("id+idx", pc.add(t["id"], y)), True, None, {"y": 3}), # same with bs=None + (5, lambda t, y=0: t.append_column("id+idx", pc.add(t["id"], y)), True, -1, {"y": 3}), # same with bs<=0 + ], +) +def 
+    base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n})
+    ex_iterable = MappedExamplesIterable(
+        base_ex_iterable, func, batched=batched, batch_size=batch_size, fn_kwargs=fn_kwargs, format_type="arrow"
+    )
+    all_examples = [x for _, x in generate_examples_fn(n=n)]
+    if fn_kwargs is None:
+        fn_kwargs = {}
+    if batched is False:
+        expected = [func(pa.Table.from_pylist([x]), **fn_kwargs).to_pylist()[0] for x in all_examples]
+    else:
+        expected = []
+        # If batch_size is None or <=0, we use the whole dataset as a single batch
+        if batch_size is None or batch_size <= 0:
+            batch_size = len(all_examples)
+        for batch_offset in range(0, len(all_examples), batch_size):
+            examples = all_examples[batch_offset : batch_offset + batch_size]
+            batch = pa.Table.from_pylist(examples)
+            expected.extend(func(batch, **fn_kwargs).to_pylist())
+    assert next(iter(ex_iterable))[1] == expected[0]
+    assert [x for _, x in ex_iterable] == expected
+
+
+@pytest.mark.parametrize(
+    "n, func, batched, batch_size, input_columns",
+    [
+        (3, lambda id_: pa.table({"id+1": pc.add(id_, 1)}), False, None, ["id"]),  # just add 1 to the id
+        (25, lambda ids_: pa.table({"id+1": pc.add(ids_, 1)}), True, 10, ["id"]),  # same with bs=10
+        (5, lambda ids_: pa.table({"id+1": pc.add(ids_, 1)}), True, None, ["id"]),  # same with bs=None
+        (5, lambda ids_: pa.table({"id+1": pc.add(ids_, 1)}), True, -1, ["id"]),  # same with bs<=0
+    ],
+)
+def test_mapped_examples_iterable_input_columns_and_arrow_format(n, func, batched, batch_size, input_columns):
+    base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n})
+    ex_iterable = MappedExamplesIterable(
+        base_ex_iterable,
+        func,
+        batched=batched,
+        batch_size=batch_size,
+        input_columns=input_columns,
+        format_type="arrow",
+    )
+    all_examples = [x for _, x in generate_examples_fn(n=n)]
+    columns_to_input = input_columns if isinstance(input_columns, list) else [input_columns]
+    if batched is False:
+        expected = [
+            func(*[pa.Table.from_pylist([x])[col] for col in columns_to_input]).to_pylist()[0] for x in all_examples
+        ]
+    else:
+        expected = []
+        # If batch_size is None or <=0, we use the whole dataset as a single batch
+        if batch_size is None or batch_size <= 0:
+            batch_size = len(all_examples)
+        for batch_offset in range(0, len(all_examples), batch_size):
+            examples = all_examples[batch_offset : batch_offset + batch_size]
+            batch = pa.Table.from_pylist(examples)
+            expected.extend(func(*[batch[col] for col in columns_to_input]).to_pylist())
+    assert next(iter(ex_iterable))[1] == expected[0]
+    assert [x for _, x in ex_iterable] == expected
+
+
 @pytest.mark.parametrize(
     "n, func, batched, batch_size",
     [

From 00b148b09da2074fcaba0538a23c7f46d28d387c Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Wed, 10 May 2023 14:03:55 +0200
Subject: [PATCH 07/10] use ArrowExamplesIterable in ArrowBasedBuilder.as_streaming_dataset

---
 src/datasets/builder.py          |  6 ++----
 src/datasets/iterable_dataset.py | 11 -----------
 2 files changed, 2 insertions(+), 15 deletions(-)

diff --git a/src/datasets/builder.py b/src/datasets/builder.py
index eb57d0c66b0..ff00178e97c 100644
--- a/src/datasets/builder.py
+++ b/src/datasets/builder.py
@@ -56,7 +56,7 @@
 from .filesystems import is_remote_filesystem
 from .fingerprint import Hasher
 from .info import DatasetInfo, DatasetInfosDict, PostProcessedInfo
-from .iterable_dataset import ExamplesIterable, IterableDataset, _generate_examples_from_tables_wrapper
+from .iterable_dataset import ArrowExamplesIterable, ExamplesIterable, IterableDataset
 from .keyhash import DuplicatedKeysError
 from .naming import INVALID_WINDOWS_CHARACTERS_IN_PATH, camelcase_to_snakecase
 from .splits import Split, SplitDict, SplitGenerator, SplitInfo
@@ -1895,9 +1895,7 @@ def _prepare_split_single(
         yield job_id, True, (total_num_examples, total_num_bytes, writer._features, num_shards, shard_lengths)
 
     def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable:
-        return ExamplesIterable(
-            _generate_examples_from_tables_wrapper(self._generate_tables), kwargs=split_generator.gen_kwargs
-        )
+        return ArrowExamplesIterable(self._generate_tables, kwargs=split_generator.gen_kwargs)
 
 
 class MissingBeamOptions(ValueError):
diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py
index e746f0ad305..17a153fbe23 100644
--- a/src/datasets/iterable_dataset.py
+++ b/src/datasets/iterable_dataset.py
@@ -1043,17 +1043,6 @@ def n_shards(self) -> int:
         return self.ex_iterable.n_shards
 
 
-def _generate_examples_from_tables_wrapper(generate_tables_fn):
-    def wrapper(**kwargs):
-        python_formatter = PythonFormatter()
-        for key, table in generate_tables_fn(**kwargs):
-            batch = python_formatter.format_batch(table)
-            for i, example in enumerate(_batch_to_examples(batch)):
-                yield f"{key}_{i}", example
-
-    return wrapper
-
-
 @dataclass
 class ShufflingConfig:
     generator: np.random.Generator

From bd373f69f12e926f4e2a489c14df36c38ce07bcc Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Wed, 24 May 2023 17:21:02 +0200
Subject: [PATCH 08/10] missing fn_kwargs in filter

---
 src/datasets/iterable_dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py
index fbb51dff7c3..e7185bf1f55 100644
--- a/src/datasets/iterable_dataset.py
+++ b/src/datasets/iterable_dataset.py
@@ -864,7 +864,7 @@ def _iter_arrow(self):
                 else:
                     function_args.append(current_idx)
             # then apply the transform
-            mask = self.function(*function_args)
+            mask = self.function(*function_args, **self.fn_kwargs)
             # yield the filtered table
             if self.batched:
                 yield key, pa_table.filter(mask)

From f2778e1ab255545cb2171379fd2276c85768a2ad Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Fri, 26 May 2023 20:03:40 +0200
Subject: [PATCH 09/10] albert's comments

---
 src/datasets/iterable_dataset.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py
index e7185bf1f55..c78b7e449a6 100644
--- a/src/datasets/iterable_dataset.py
+++ b/src/datasets/iterable_dataset.py
@@ -87,9 +87,9 @@ def hasnext(self):
 def _convert_to_arrow(
     iterable: Iterable[Tuple[Key, dict]],
     batch_size: int,
-    drop_last_batch=False,
+    drop_last_batch: bool = False,
 ) -> Iterator[Tuple[Key, pa.Table]]:
-    """Iterate over sub-tables of size `batch_size`.
+    """Convert and group examples into Arrow tables of size `batch_size`.
 
     Args:
         iterable (`Iterable[Tuple[Key, dict]]`):
@@ -116,7 +116,7 @@ def _convert_to_arrow(
 def _batch_arrow_tables(
     iterable: Iterable[Tuple[Key, pa.Table]],
     batch_size: Optional[int],
-    drop_last_batch=False,
+    drop_last_batch: bool = False,
 ) -> Iterator[Tuple[Key, pa.Table]]:
     """Iterate over sub-tables of size `batch_size`.
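
For readers following the series, here is a minimal sketch of the re-batching behavior that `_batch_arrow_tables` implements. This is an illustrative reimplementation under the same conventions as the tests above, not the library code itself; the name `batch_tables` and the key-joining scheme are invented for the example.

from typing import Any, Iterable, Iterator, Optional, Tuple

import pyarrow as pa


def batch_tables(
    iterable: Iterable[Tuple[Any, pa.Table]],
    batch_size: Optional[int],
    drop_last_batch: bool = False,
) -> Iterator[Tuple[Any, pa.Table]]:
    # Convention used throughout the tests above: batch_size None or <= 0
    # means "treat the whole stream as a single batch".
    if batch_size is not None and batch_size <= 0:
        batch_size = None
    buffer, keys, num_rows = [], [], 0
    for key, table in iterable:
        buffer.append(table)
        keys.append(str(key))
        num_rows += len(table)
        # Emit exact batch_size-row slices as soon as enough rows are buffered.
        while batch_size is not None and num_rows >= batch_size:
            merged = pa.concat_tables(buffer)
            yield "_".join(keys), merged.slice(0, batch_size)
            leftover = merged.slice(batch_size)
            buffer, keys, num_rows = [leftover], [keys[-1]], len(leftover)
    # A trailing incomplete batch is kept unless drop_last_batch is set;
    # a whole-stream batch (batch_size=None) is complete by definition.
    if num_rows > 0 and (batch_size is None or not drop_last_batch):
        yield "_".join(keys), pa.concat_tables(buffer)

This is also why the `drop_last_batch` test above expects an empty iterable when `n < batch_size`: the only batch is incomplete, so it is dropped.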
From 028822a5d657f6c1251f61b56a701c4d7d2ab0a7 Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Tue, 30 May 2023 18:11:04 +0200
Subject: [PATCH 10/10] albert's comment: update docstring

---
 src/datasets/iterable_dataset.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py
index c78b7e449a6..a408fe4f4b8 100644
--- a/src/datasets/iterable_dataset.py
+++ b/src/datasets/iterable_dataset.py
@@ -1420,7 +1420,8 @@ def with_format(
     ) -> "IterableDataset":
         """
         Return a dataset with the specified format.
-        This method only supports the "torch" format for now.
+        Supported formats: "arrow", or None for regular python objects.
+        The other formats are currently not implemented.
 
         Args:
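
To see the behavior the updated docstring describes end to end, here is a short usage sketch (assuming the whole patch series is applied; the toy column and shard count are illustrative):

import pyarrow as pa

from datasets import Dataset

ids = Dataset.from_dict({"id": list(range(10))}).to_iterable_dataset(num_shards=2)

# Default format (None): iterating yields plain python dicts.
assert isinstance(next(iter(ids)), dict)

# "arrow" format: iterating yields pyarrow Tables, so map/filter functions
# receive Arrow data directly and skip the python conversion round-trip.
assert isinstance(next(iter(ids.with_format("arrow"))), pa.Table)

Combined with the `ArrowExamplesIterable` changes above, this keeps data in Arrow from the underlying shards all the way through map and filter pipelines.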