From a905a19ba3e0bb3aebe8fc6dc6747f6908e9c95e Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Tue, 2 May 2023 18:49:19 +0200 Subject: [PATCH] some tests --- tests/test_iterable_dataset.py | 99 ++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 50d1e35de62..9915163a173 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd +import pyarrow as pa import pytest from datasets import load_dataset @@ -11,6 +12,7 @@ from datasets.formatting import get_format_type_from_alias from datasets.info import DatasetInfo from datasets.iterable_dataset import ( + ArrowExamplesIterable, BufferShuffledExamplesIterable, CyclingMultiSourcesExamplesIterable, ExamplesIterable, @@ -18,11 +20,13 @@ HorizontallyConcatenatedMultiSourcesExamplesIterable, IterableDataset, MappedExamplesIterable, + PythonToArrowExamplesIterable, RandomlyCyclingMultiSourcesExamplesIterable, ShufflingConfig, SkipExamplesIterable, TakeExamplesIterable, VerticallyConcatenatedMultiSourcesExamplesIterable, + _batch_arrow_tables, _batch_to_examples, _examples_to_batch, ) @@ -31,6 +35,7 @@ DEFAULT_N_EXAMPLES = 20 +DEFAULT_BATCH_SIZE = 4 DEFAULT_FILEPATH = "file.txt" SAMPLE_DATASET_IDENTIFIER = "lhoestq/test" # has dataset script @@ -47,6 +52,25 @@ def generate_examples_fn(**kwargs): yield f"{filepath}_{i}", {"id": i, **kwargs} +def generate_tables_fn(**kwargs): + kwargs = kwargs.copy() + n = kwargs.pop("n", DEFAULT_N_EXAMPLES) + batch_size = kwargs.pop("batch_size", DEFAULT_BATCH_SIZE) + filepaths = kwargs.pop("filepaths", None) + for filepath in filepaths or [DEFAULT_FILEPATH]: + buffer = [] + batch_idx = 0 + if filepaths is not None: + kwargs["filepath"] = filepath + for i in range(n): + buffer.append({"id": i, **kwargs}) + if len(buffer) == batch_size: + yield f"{filepath}_{batch_idx}", pa.Table.from_pylist(buffer) + buffer = [] + batch_idx += 1 + yield batch_idx, pa.Table.from_pylist(buffer) + + @pytest.fixture def dataset(): ex_iterable = ExamplesIterable(generate_examples_fn, {}) @@ -62,6 +86,39 @@ def dataset_with_several_columns(): return IterableDataset(ex_iterable, info=DatasetInfo(description="dummy"), split="train") +################################ +# +# Utilities tests +# +################################ + + +@pytest.mark.parametrize( + "tables", + [ + [pa.table({"foo": range(10)})], + [pa.table({"foo": range(0, 5)}), pa.table({"foo": range(5, 10)})], + [pa.table({"foo": [i]}) for i in range(10)], + ], +) +@pytest.mark.parametrize("batch_size", [1, 2, 3, 9, 10, 11, 20]) +@pytest.mark.parametrize("drop_last_batch", [False, True]) +def test_batch_arrow_tables(tables, batch_size, drop_last_batch): + full_table = pa.concat_tables(tables) + num_rows = len(full_table) if not drop_last_batch else len(full_table) // batch_size * batch_size + num_batches = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size + subtables = list(_batch_arrow_tables(tables, batch_size=batch_size, drop_last_batch=drop_last_batch)) + assert len(subtables) == num_batches + if drop_last_batch: + assert all(len(subtable) == batch_size for subtable in subtables) + else: + assert all(len(subtable) == batch_size for subtable in subtables[:-1]) + assert len(subtables[-1]) <= batch_size + if num_rows > 0: + reloaded = pa.concat_tables(subtables) + assert full_table.slice(0, num_rows).to_pydict() == reloaded.to_pydict() + + ################################ # # _BaseExampleIterable tests @@ -74,6 +131,7 @@ def test_examples_iterable(): expected = list(generate_examples_fn()) assert next(iter(ex_iterable)) == expected[0] assert list(ex_iterable) == expected + assert ex_iterable.iter_arrow is None def test_examples_iterable_with_kwargs(): @@ -110,6 +168,47 @@ def gen(filepaths, all_metadata): assert filepaths_ids == metadata_ids, "entangled lists of shards/metadata should be shuffled the same way" +def test_arrow_examples_iterable(): + ex_iterable = ArrowExamplesIterable(generate_tables_fn, {}) + expected = sum([pa_table.to_pylist() for _, pa_table in generate_tables_fn()], []) + assert next(iter(ex_iterable))[1] == expected[0] + assert [example for _, example in ex_iterable] == expected + expected = list(generate_tables_fn()) + assert list(ex_iterable.iter_arrow()) == expected + + +def test_arrow_examples_iterable_with_kwargs(): + ex_iterable = ArrowExamplesIterable(generate_tables_fn, {"filepaths": ["0.txt", "1.txt"], "split": "train"}) + expected = sum( + [pa_table.to_pylist() for _, pa_table in generate_tables_fn(filepaths=["0.txt", "1.txt"], split="train")], [] + ) + assert [example for _, example in ex_iterable] == expected + assert all("split" in ex for _, ex in ex_iterable) + assert sorted({ex["filepath"] for _, ex in ex_iterable}) == ["0.txt", "1.txt"] + expected = list(generate_tables_fn(filepaths=["0.txt", "1.txt"], split="train")) + assert list(ex_iterable.iter_arrow()) == expected + + +def test_arrow_examples_iterable_shuffle_data_sources(): + ex_iterable = ArrowExamplesIterable(generate_tables_fn, {"filepaths": ["0.txt", "1.txt"]}) + ex_iterable = ex_iterable.shuffle_data_sources(np.random.default_rng(40)) + expected = sum( + [pa_table.to_pylist() for _, pa_table in generate_tables_fn(filepaths=["1.txt", "0.txt"])], [] + ) # shuffle the filepaths + assert [example for _, example in ex_iterable] == expected + expected = list(generate_tables_fn(filepaths=["1.txt", "0.txt"])) + assert list(ex_iterable.iter_arrow()) == expected + + +def test_python_to_arrow_examples_iterable(): + python_ex_iterable = ExamplesIterable(generate_examples_fn, {}) + arrow_ex_iterable = PythonToArrowExamplesIterable(python_ex_iterable) + assert list(python_ex_iterable) == list(arrow_ex_iterable) + tables = [pa_table for _, pa_table in arrow_ex_iterable.iter_arrow()] + assert 1 < len(tables) < len(list(python_ex_iterable)) + assert pa.Table.from_pylist([example for _, example in python_ex_iterable]) == pa.concat_tables(tables) + + @pytest.mark.parametrize("seed", [42, 1337, 101010, 123456]) def test_buffer_shuffled_examples_iterable(seed): n, buffer_size = 100, 30