Commit a905a19

some tests
lhoestq committed May 2, 2023
1 parent b860cf6 commit a905a19
Showing 1 changed file with 99 additions and 0 deletions.
tests/test_iterable_dataset.py: 99 additions & 0 deletions
@@ -3,6 +3,7 @@

import numpy as np
import pandas as pd
import pyarrow as pa
import pytest

from datasets import load_dataset
@@ -11,18 +12,21 @@
from datasets.formatting import get_format_type_from_alias
from datasets.info import DatasetInfo
from datasets.iterable_dataset import (
ArrowExamplesIterable,
BufferShuffledExamplesIterable,
CyclingMultiSourcesExamplesIterable,
ExamplesIterable,
FilteredExamplesIterable,
HorizontallyConcatenatedMultiSourcesExamplesIterable,
IterableDataset,
MappedExamplesIterable,
PythonToArrowExamplesIterable,
RandomlyCyclingMultiSourcesExamplesIterable,
ShufflingConfig,
SkipExamplesIterable,
TakeExamplesIterable,
VerticallyConcatenatedMultiSourcesExamplesIterable,
_batch_arrow_tables,
_batch_to_examples,
_examples_to_batch,
)
@@ -31,6 +35,7 @@


DEFAULT_N_EXAMPLES = 20
DEFAULT_BATCH_SIZE = 4
DEFAULT_FILEPATH = "file.txt"

SAMPLE_DATASET_IDENTIFIER = "lhoestq/test" # has dataset script
@@ -47,6 +52,25 @@ def generate_examples_fn(**kwargs):
yield f"{filepath}_{i}", {"id": i, **kwargs}


def generate_tables_fn(**kwargs):
    kwargs = kwargs.copy()
    n = kwargs.pop("n", DEFAULT_N_EXAMPLES)
    batch_size = kwargs.pop("batch_size", DEFAULT_BATCH_SIZE)
    filepaths = kwargs.pop("filepaths", None)
    for filepath in filepaths or [DEFAULT_FILEPATH]:
        buffer = []
        batch_idx = 0
        if filepaths is not None:
            kwargs["filepath"] = filepath
        for i in range(n):
            buffer.append({"id": i, **kwargs})
            if len(buffer) == batch_size:
                yield f"{filepath}_{batch_idx}", pa.Table.from_pylist(buffer)
                buffer = []
                batch_idx += 1
        if buffer:  # flush the last partial batch, keyed like the full ones
            yield f"{filepath}_{batch_idx}", pa.Table.from_pylist(buffer)
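
For orientation, a minimal sketch (not part of the commit) of what this helper emits with the defaults above, n=20 and batch_size=4:

# Each filepath yields ceil(n / batch_size) (key, table) pairs; here 5 tables of 4 rows.
for key, pa_table in generate_tables_fn():
    print(key, pa_table.num_rows, pa_table.column_names)
# file.txt_0 4 ['id']
# ...
# file.txt_4 4 ['id']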


@pytest.fixture
def dataset():
ex_iterable = ExamplesIterable(generate_examples_fn, {})
@@ -62,6 +86,39 @@ def dataset_with_several_columns():
return IterableDataset(ex_iterable, info=DatasetInfo(description="dummy"), split="train")


################################
#
# Utilities tests
#
################################


@pytest.mark.parametrize(
"tables",
[
[pa.table({"foo": range(10)})],
[pa.table({"foo": range(0, 5)}), pa.table({"foo": range(5, 10)})],
[pa.table({"foo": [i]}) for i in range(10)],
],
)
@pytest.mark.parametrize("batch_size", [1, 2, 3, 9, 10, 11, 20])
@pytest.mark.parametrize("drop_last_batch", [False, True])
def test_batch_arrow_tables(tables, batch_size, drop_last_batch):
full_table = pa.concat_tables(tables)
num_rows = len(full_table) if not drop_last_batch else len(full_table) // batch_size * batch_size
num_batches = (num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size
subtables = list(_batch_arrow_tables(tables, batch_size=batch_size, drop_last_batch=drop_last_batch))
assert len(subtables) == num_batches
if drop_last_batch:
assert all(len(subtable) == batch_size for subtable in subtables)
else:
assert all(len(subtable) == batch_size for subtable in subtables[:-1])
assert len(subtables[-1]) <= batch_size
if num_rows > 0:
reloaded = pa.concat_tables(subtables)
assert full_table.slice(0, num_rows).to_pydict() == reloaded.to_pydict()
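
As a concrete instance of the contract the parametrization above pins down (a sketch using only the signature the test itself exercises; `_batch_arrow_tables` is a private helper):

# 10 rows re-chunked into batches of 3 without dropping the remainder.
tables = [pa.table({"foo": range(0, 5)}), pa.table({"foo": range(5, 10)})]
batches = list(_batch_arrow_tables(tables, batch_size=3, drop_last_batch=False))
print([len(b) for b in batches])  # [3, 3, 3, 1]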


################################
#
# _BaseExampleIterable tests
@@ -74,6 +131,7 @@ def test_examples_iterable():
expected = list(generate_examples_fn())
assert next(iter(ex_iterable)) == expected[0]
assert list(ex_iterable) == expected
assert ex_iterable.iter_arrow is None
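
The new `iter_arrow is None` assertion suggests how a consumer can detect an Arrow fast path (a hedged sketch of the dispatch pattern these tests imply, not the library's actual code):

def iterate_fastest(ex_iterable):
    # Prefer the Arrow path when the iterable advertises one.
    if ex_iterable.iter_arrow is not None:
        yield from ex_iterable.iter_arrow()  # (key, pa.Table) pairs
    else:
        yield from ex_iterable  # (key, example dict) pairs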


def test_examples_iterable_with_kwargs():
@@ -110,6 +168,47 @@ def gen(filepaths, all_metadata):
assert filepaths_ids == metadata_ids, "entangled lists of shards/metadata should be shuffled the same way"


def test_arrow_examples_iterable():
ex_iterable = ArrowExamplesIterable(generate_tables_fn, {})
expected = sum([pa_table.to_pylist() for _, pa_table in generate_tables_fn()], [])
assert next(iter(ex_iterable))[1] == expected[0]
assert [example for _, example in ex_iterable] == expected
expected = list(generate_tables_fn())
assert list(ex_iterable.iter_arrow()) == expected
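
Following the fixture pattern at the top of the file, an ArrowExamplesIterable should presumably back an IterableDataset the same way an ExamplesIterable does (a sketch, not something this commit asserts):

ex_iterable = ArrowExamplesIterable(generate_tables_fn, {})
dataset = IterableDataset(ex_iterable, info=DatasetInfo(description="dummy"), split="train")
print(next(iter(dataset)))  # {'id': 0}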


def test_arrow_examples_iterable_with_kwargs():
ex_iterable = ArrowExamplesIterable(generate_tables_fn, {"filepaths": ["0.txt", "1.txt"], "split": "train"})
expected = sum(
[pa_table.to_pylist() for _, pa_table in generate_tables_fn(filepaths=["0.txt", "1.txt"], split="train")], []
)
assert [example for _, example in ex_iterable] == expected
assert all("split" in ex for _, ex in ex_iterable)
assert sorted({ex["filepath"] for _, ex in ex_iterable}) == ["0.txt", "1.txt"]
expected = list(generate_tables_fn(filepaths=["0.txt", "1.txt"], split="train"))
assert list(ex_iterable.iter_arrow()) == expected


def test_arrow_examples_iterable_shuffle_data_sources():
ex_iterable = ArrowExamplesIterable(generate_tables_fn, {"filepaths": ["0.txt", "1.txt"]})
ex_iterable = ex_iterable.shuffle_data_sources(np.random.default_rng(40))
expected = sum(
[pa_table.to_pylist() for _, pa_table in generate_tables_fn(filepaths=["1.txt", "0.txt"])], []
) # shuffle the filepaths
assert [example for _, example in ex_iterable] == expected
expected = list(generate_tables_fn(filepaths=["1.txt", "0.txt"]))
assert list(ex_iterable.iter_arrow()) == expected


def test_python_to_arrow_examples_iterable():
python_ex_iterable = ExamplesIterable(generate_examples_fn, {})
arrow_ex_iterable = PythonToArrowExamplesIterable(python_ex_iterable)
assert list(python_ex_iterable) == list(arrow_ex_iterable)
tables = [pa_table for _, pa_table in arrow_ex_iterable.iter_arrow()]
assert 1 < len(tables) < len(list(python_ex_iterable))
assert pa.Table.from_pylist([example for _, example in python_ex_iterable]) == pa.concat_tables(tables)
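
The assertion `1 < len(tables) < n` above implies the wrapper groups several python examples into each Arrow table rather than emitting one table per example. A conceptually equivalent re-implementation (the batch size and key scheme are assumptions for illustration, not the library's actual internals):

def python_examples_to_tables(examples_iterable, batch_size=10):
    # Accumulate python examples and emit one Arrow table per full batch.
    batch, table_idx = [], 0
    for _, example in examples_iterable:
        batch.append(example)
        if len(batch) == batch_size:
            yield table_idx, pa.Table.from_pylist(batch)
            batch, table_idx = [], table_idx + 1
    if batch:  # emit the final partial batch, if any
        yield table_idx, pa.Table.from_pylist(batch)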


@pytest.mark.parametrize("seed", [42, 1337, 101010, 123456])
def test_buffer_shuffled_examples_iterable(seed):
n, buffer_size = 100, 30

1 comment on commit a905a19

@github-actions

PyArrow==8.0.0

Benchmark: benchmark_array_xd.json

| metric | new / old (diff) |
|---|---|
| read_batch_formatted_as_numpy after write_array2d | 0.006639 / 0.011353 (-0.004714) |
| read_batch_formatted_as_numpy after write_flattened_sequence | 0.004483 / 0.011008 (-0.006526) |
| read_batch_formatted_as_numpy after write_nested_sequence | 0.098749 / 0.038508 (0.060241) |
| read_batch_unformated after write_array2d | 0.028064 / 0.023109 (0.004955) |
| read_batch_unformated after write_flattened_sequence | 0.348780 / 0.275898 (0.072882) |
| read_batch_unformated after write_nested_sequence | 0.406462 / 0.323480 (0.082982) |
| read_col_formatted_as_numpy after write_array2d | 0.004997 / 0.007986 (-0.002988) |
| read_col_formatted_as_numpy after write_flattened_sequence | 0.004616 / 0.004328 (0.000287) |
| read_col_formatted_as_numpy after write_nested_sequence | 0.077603 / 0.004250 (0.073352) |
| read_col_unformated after write_array2d | 0.039602 / 0.037052 (0.002550) |
| read_col_unformated after write_flattened_sequence | 0.373884 / 0.258489 (0.115395) |
| read_col_unformated after write_nested_sequence | 0.412810 / 0.293841 (0.118969) |
| read_formatted_as_numpy after write_array2d | 0.030045 / 0.128546 (-0.098502) |
| read_formatted_as_numpy after write_flattened_sequence | 0.011442 / 0.075646 (-0.064204) |
| read_formatted_as_numpy after write_nested_sequence | 0.322611 / 0.419271 (-0.096660) |
| read_unformated after write_array2d | 0.042686 / 0.043533 (-0.000847) |
| read_unformated after write_flattened_sequence | 0.345500 / 0.255139 (0.090361) |
| read_unformated after write_nested_sequence | 0.390711 / 0.283200 (0.107511) |
| write_array2d | 0.090020 / 0.141683 (-0.051663) |
| write_flattened_sequence | 1.462101 / 1.452155 (0.009947) |
| write_nested_sequence | 1.541990 / 1.492716 (0.049274) |

Benchmark: benchmark_getitem_100B.json

| metric | new / old (diff) |
|---|---|
| get_batch_of_1024_random_rows | 0.196683 / 0.018006 (0.178677) |
| get_batch_of_1024_rows | 0.409627 / 0.000490 (0.409137) |
| get_first_row | 0.000324 / 0.000200 (0.000125) |
| get_last_row | 0.000060 / 0.000054 (0.000005) |

Benchmark: benchmark_indices_mapping.json

| metric | new / old (diff) |
|---|---|
| select | 0.023133 / 0.037411 (-0.014279) |
| shard | 0.099351 / 0.014526 (0.084825) |
| shuffle | 0.104848 / 0.176557 (-0.071708) |
| sort | 0.163629 / 0.737135 (-0.573506) |
| train_test_split | 0.108955 / 0.296338 (-0.187384) |

Benchmark: benchmark_iterating.json

| metric | new / old (diff) |
|---|---|
| read 5000 | 0.462165 / 0.215209 (0.246956) |
| read 50000 | 4.599674 / 2.077655 (2.522019) |
| read_batch 50000 10 | 2.207452 / 1.504120 (0.703332) |
| read_batch 50000 100 | 1.984886 / 1.541195 (0.443691) |
| read_batch 50000 1000 | 2.062392 / 1.468490 (0.593902) |
| read_formatted numpy 5000 | 0.700187 / 4.584777 (-3.884590) |
| read_formatted pandas 5000 | 3.389194 / 3.745712 (-0.356518) |
| read_formatted tensorflow 5000 | 1.870007 / 5.269862 (-3.399854) |
| read_formatted torch 5000 | 1.167780 / 4.565676 (-3.397897) |
| read_formatted_batch numpy 5000 10 | 0.083597 / 0.424275 (-0.340678) |
| read_formatted_batch numpy 5000 1000 | 0.012514 / 0.007607 (0.004907) |
| shuffled read 5000 | 0.565803 / 0.226044 (0.339758) |
| shuffled read 50000 | 5.691154 / 2.268929 (3.422226) |
| shuffled read_batch 50000 10 | 2.670939 / 55.444624 (-52.773685) |
| shuffled read_batch 50000 100 | 2.331220 / 6.876477 (-4.545257) |
| shuffled read_batch 50000 1000 | 2.447013 / 2.142072 (0.304940) |
| shuffled read_formatted numpy 5000 | 0.807979 / 4.805227 (-3.997248) |
| shuffled read_formatted_batch numpy 5000 10 | 0.152627 / 6.500664 (-6.348037) |
| shuffled read_formatted_batch numpy 5000 1000 | 0.067586 / 0.075469 (-0.007883) |

Benchmark: benchmark_map_filter.json

| metric | new / old (diff) |
|---|---|
| filter | 1.196279 / 1.841788 (-0.645509) |
| map fast-tokenizer batched | 13.615878 / 8.074308 (5.541570) |
| map identity | 14.055406 / 10.191392 (3.864014) |
| map identity batched | 0.143053 / 0.680424 (-0.537371) |
| map no-op batched | 0.016554 / 0.534201 (-0.517647) |
| map no-op batched numpy | 0.379872 / 0.579283 (-0.199411) |
| map no-op batched pandas | 0.377914 / 0.434364 (-0.056450) |
| map no-op batched pytorch | 0.443907 / 0.540337 (-0.096430) |
| map no-op batched tensorflow | 0.519589 / 1.386936 (-0.867347) |
PyArrow==latest

Benchmark: benchmark_array_xd.json

| metric | new / old (diff) |
|---|---|
| read_batch_formatted_as_numpy after write_array2d | 0.006616 / 0.011353 (-0.004737) |
| read_batch_formatted_as_numpy after write_flattened_sequence | 0.004608 / 0.011008 (-0.006400) |
| read_batch_formatted_as_numpy after write_nested_sequence | 0.076517 / 0.038508 (0.038009) |
| read_batch_unformated after write_array2d | 0.028378 / 0.023109 (0.005269) |
| read_batch_unformated after write_flattened_sequence | 0.360746 / 0.275898 (0.084848) |
| read_batch_unformated after write_nested_sequence | 0.399471 / 0.323480 (0.075991) |
| read_col_formatted_as_numpy after write_array2d | 0.005128 / 0.007986 (-0.002857) |
| read_col_formatted_as_numpy after write_flattened_sequence | 0.003385 / 0.004328 (-0.000943) |
| read_col_formatted_as_numpy after write_nested_sequence | 0.075087 / 0.004250 (0.070836) |
| read_col_unformated after write_array2d | 0.038537 / 0.037052 (0.001485) |
| read_col_unformated after write_flattened_sequence | 0.361203 / 0.258489 (0.102714) |
| read_col_unformated after write_nested_sequence | 0.409056 / 0.293841 (0.115215) |
| read_formatted_as_numpy after write_array2d | 0.030911 / 0.128546 (-0.097635) |
| read_formatted_as_numpy after write_flattened_sequence | 0.011609 / 0.075646 (-0.064037) |
| read_formatted_as_numpy after write_nested_sequence | 0.085032 / 0.419271 (-0.334240) |
| read_unformated after write_array2d | 0.040566 / 0.043533 (-0.002967) |
| read_unformated after write_flattened_sequence | 0.357468 / 0.255139 (0.102329) |
| read_unformated after write_nested_sequence | 0.380218 / 0.283200 (0.097019) |
| write_array2d | 0.096357 / 0.141683 (-0.045326) |
| write_flattened_sequence | 1.493651 / 1.452155 (0.041496) |
| write_nested_sequence | 1.629957 / 1.492716 (0.137241) |

Benchmark: benchmark_getitem_100B.json

| metric | new / old (diff) |
|---|---|
| get_batch_of_1024_random_rows | 0.206806 / 0.018006 (0.188800) |
| get_batch_of_1024_rows | 0.413668 / 0.000490 (0.413179) |
| get_first_row | 0.000398 / 0.000200 (0.000198) |
| get_last_row | 0.000058 / 0.000054 (0.000004) |

Benchmark: benchmark_indices_mapping.json

| metric | new / old (diff) |
|---|---|
| select | 0.025954 / 0.037411 (-0.011457) |
| shard | 0.102968 / 0.014526 (0.088443) |
| shuffle | 0.109725 / 0.176557 (-0.066832) |
| sort | 0.161428 / 0.737135 (-0.575708) |
| train_test_split | 0.112945 / 0.296338 (-0.183394) |

Benchmark: benchmark_iterating.json

| metric | new / old (diff) |
|---|---|
| read 5000 | 0.451010 / 0.215209 (0.235801) |
| read 50000 | 4.528029 / 2.077655 (2.450374) |
| read_batch 50000 10 | 2.204665 / 1.504120 (0.700545) |
| read_batch 50000 100 | 2.043047 / 1.541195 (0.501852) |
| read_batch 50000 1000 | 2.174045 / 1.468490 (0.705555) |
| read_formatted numpy 5000 | 0.693503 / 4.584777 (-3.891274) |
| read_formatted pandas 5000 | 3.377672 / 3.745712 (-0.368040) |
| read_formatted tensorflow 5000 | 1.878218 / 5.269862 (-3.391643) |
| read_formatted torch 5000 | 1.161110 / 4.565676 (-3.404567) |
| read_formatted_batch numpy 5000 10 | 0.082983 / 0.424275 (-0.341292) |
| read_formatted_batch numpy 5000 1000 | 0.012380 / 0.007607 (0.004773) |
| shuffled read 5000 | 0.549091 / 0.226044 (0.323047) |
| shuffled read 50000 | 5.496317 / 2.268929 (3.227388) |
| shuffled read_batch 50000 10 | 2.658465 / 55.444624 (-52.786160) |
| shuffled read_batch 50000 100 | 2.406736 / 6.876477 (-4.469741) |
| shuffled read_batch 50000 1000 | 2.581959 / 2.142072 (0.439887) |
| shuffled read_formatted numpy 5000 | 0.811179 / 4.805227 (-3.994049) |
| shuffled read_formatted_batch numpy 5000 10 | 0.153234 / 6.500664 (-6.347430) |
| shuffled read_formatted_batch numpy 5000 1000 | 0.068379 / 0.075469 (-0.007090) |

Benchmark: benchmark_map_filter.json

| metric | new / old (diff) |
|---|---|
| filter | 1.342339 / 1.841788 (-0.499449) |
| map fast-tokenizer batched | 14.269001 / 8.074308 (6.194693) |
| map identity | 14.213221 / 10.191392 (4.021829) |
| map identity batched | 0.173370 / 0.680424 (-0.507053) |
| map no-op batched | 0.016874 / 0.534201 (-0.517327) |
| map no-op batched numpy | 0.390088 / 0.579283 (-0.189195) |
| map no-op batched pandas | 0.428426 / 0.434364 (-0.005938) |
| map no-op batched pytorch | 0.453240 / 0.540337 (-0.087097) |
| map no-op batched tensorflow | 0.538492 / 1.386936 (-0.848444) |
