Skip to content

Commit

Permalink
add iterable arrow formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
lhoestq committed May 2, 2023
1 parent 649d5a3 commit b860cf6
Show file tree
Hide file tree
Showing 2 changed files with 301 additions and 17 deletions.
23 changes: 10 additions & 13 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
update_fingerprint,
validate_fingerprint,
)
from .formatting import PythonFormatter, format_table, get_format_type_from_alias, get_formatter, query_table
from .formatting import format_table, get_format_type_from_alias, get_formatter, query_table
from .formatting.formatting import LazyDict, _is_range_contiguous
from .info import DatasetInfo, DatasetInfosDict
from .naming import _split_re
Expand Down Expand Up @@ -4933,16 +4933,10 @@ def extra_nbytes_visitor(array, feature):
return dataset_nbytes

@staticmethod
def _generate_examples_from_shards(shards: List["Dataset"]):
python_formatter = PythonFormatter()
for shards_idx, shard in enumerate(shards):
example_idx = 0
for pa_table in shard.with_format("arrow").iter(config.ARROW_READER_BATCH_SIZE_IN_DATASET_ITER):
batch = python_formatter.format_batch(pa_table)
for i in range(len(pa_table)):
example = {col: array[i] for col, array in batch.items()}
yield f"{shards_idx}_{example_idx}", example
example_idx += 1
def _generate_tables_from_shards(shards: List["Dataset"], batch_size: int):
    """Lazily yield ``(shard_index, pyarrow.Table)`` pairs from a list of dataset shards.

    Each shard is iterated in Arrow format so batches come out as ``pyarrow.Table``
    objects of at most ``batch_size`` rows, without converting rows to Python objects.

    Args:
        shards: The dataset shards to read from, in order.
        batch_size: Maximum number of rows per yielded Arrow table.

    Yields:
        Tuples of the shard's position in ``shards`` and one Arrow table batch.
    """
    for idx, dataset_shard in enumerate(shards):
        arrow_batches = dataset_shard.with_format("arrow").iter(batch_size)
        for table in arrow_batches:
            yield idx, table

def to_iterable_dataset(self, num_shards: Optional[int] = 1) -> "IterableDataset":
"""Get an [`datasets.IterableDataset`] from a map-style [`datasets.Dataset`].
Expand Down Expand Up @@ -5035,7 +5029,7 @@ def to_iterable_dataset(self, num_shards: Optional[int] = 1) -> "IterableDataset
```
Feel free to also use [`IterableDataset.set_epoch`] when using a PyTorch DataLoader or in distributed setups.
"""
from .iterable_dataset import ExamplesIterable, IterableDataset
from .iterable_dataset import ArrowExamplesIterable, IterableDataset

if self._format_type is not None:
raise NotImplementedError(
Expand All @@ -5057,7 +5051,10 @@ def to_iterable_dataset(self, num_shards: Optional[int] = 1) -> "IterableDataset
self.shard(num_shards=num_shards, index=shard_idx, contiguous=True) for shard_idx in range(num_shards)
]
)
ex_iterable = ExamplesIterable(Dataset._generate_examples_from_shards, kwargs={"shards": shards})
ex_iterable = ArrowExamplesIterable(
Dataset._generate_tables_from_shards,
kwargs={"shards": shards, "batch_size": config.DEFAULT_MAX_BATCH_SIZE},
)
return IterableDataset(ex_iterable, info=DatasetInfo(features=self.features))

def _push_parquet_shards_to_hub(
Expand Down
Loading

1 comment on commit b860cf6

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Show benchmarks

PyArrow==8.0.0

Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.010052 / 0.011353 (-0.001301) 0.006442 / 0.011008 (-0.004567) 0.138063 / 0.038508 (0.099555) 0.042106 / 0.023109 (0.018997) 0.421951 / 0.275898 (0.146053) 0.450478 / 0.323480 (0.126998) 0.007873 / 0.007986 (-0.000113) 0.008161 / 0.004328 (0.003832) 0.106825 / 0.004250 (0.102574) 0.059202 / 0.037052 (0.022149) 0.410885 / 0.258489 (0.152396) 0.471401 / 0.293841 (0.177560) 0.055427 / 0.128546 (-0.073119) 0.020664 / 0.075646 (-0.054982) 0.456043 / 0.419271 (0.036771) 0.070134 / 0.043533 (0.026601) 0.431497 / 0.255139 (0.176358) 0.441648 / 0.283200 (0.158449) 0.127418 / 0.141683 (-0.014265) 1.923977 / 1.452155 (0.471822) 2.014502 / 1.492716 (0.521786)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.287006 / 0.018006 (0.269000) 0.602916 / 0.000490 (0.602426) 0.005770 / 0.000200 (0.005570) 0.000136 / 0.000054 (0.000082)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.032079 / 0.037411 (-0.005333) 0.141242 / 0.014526 (0.126717) 0.141444 / 0.176557 (-0.035112) 0.215335 / 0.737135 (-0.521800) 0.149939 / 0.296338 (-0.146399)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.667005 / 0.215209 (0.451796) 6.769173 / 2.077655 (4.691518) 2.660283 / 1.504120 (1.156163) 2.268852 / 1.541195 (0.727658) 2.314830 / 1.468490 (0.846339) 1.370976 / 4.584777 (-3.213801) 5.971542 / 3.745712 (2.225830) 3.432592 / 5.269862 (-1.837269) 2.270747 / 4.565676 (-2.294930) 0.147430 / 0.424275 (-0.276845) 0.015880 / 0.007607 (0.008273) 0.856897 / 0.226044 (0.630852) 8.417554 / 2.268929 (6.148626) 3.493223 / 55.444624 (-51.951401) 2.737378 / 6.876477 (-4.139099) 2.930728 / 2.142072 (0.788655) 1.547808 / 4.805227 (-3.257419) 0.286437 / 6.500664 (-6.214227) 0.093142 / 0.075469 (0.017673)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 1.634655 / 1.841788 (-0.207133) 19.956731 / 8.074308 (11.882423) 25.251384 / 10.191392 (15.059992) 0.247999 / 0.680424 (-0.432425) 0.030560 / 0.534201 (-0.503641) 0.570228 / 0.579283 (-0.009055) 0.661190 / 0.434364 (0.226826) 0.707758 / 0.540337 (0.167421) 0.849178 / 1.386936 (-0.537758)
PyArrow==latest
Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.012790 / 0.011353 (0.001437) 0.007049 / 0.011008 (-0.003959) 0.107943 / 0.038508 (0.069435) 0.040931 / 0.023109 (0.017822) 0.509168 / 0.275898 (0.233270) 0.559844 / 0.323480 (0.236365) 0.008140 / 0.007986 (0.000155) 0.005362 / 0.004328 (0.001034) 0.104986 / 0.004250 (0.100735) 0.055270 / 0.037052 (0.018218) 0.498111 / 0.258489 (0.239622) 0.565952 / 0.293841 (0.272111) 0.052865 / 0.128546 (-0.075681) 0.022221 / 0.075646 (-0.053425) 0.128766 / 0.419271 (-0.290505) 0.059960 / 0.043533 (0.016427) 0.513084 / 0.255139 (0.257945) 0.559026 / 0.283200 (0.275826) 0.135499 / 0.141683 (-0.006184) 1.911515 / 1.452155 (0.459360) 2.090097 / 1.492716 (0.597381)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.297088 / 0.018006 (0.279082) 0.609278 / 0.000490 (0.608788) 0.000506 / 0.000200 (0.000306) 0.000088 / 0.000054 (0.000034)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.035422 / 0.037411 (-0.001989) 0.143779 / 0.014526 (0.129253) 0.153539 / 0.176557 (-0.023018) 0.220643 / 0.737135 (-0.516492) 0.159192 / 0.296338 (-0.137147)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.675406 / 0.215209 (0.460197) 6.929684 / 2.077655 (4.852029) 3.224349 / 1.504120 (1.720229) 2.799787 / 1.541195 (1.258592) 2.812097 / 1.468490 (1.343606) 1.346006 / 4.584777 (-3.238771) 5.983092 / 3.745712 (2.237380) 3.385956 / 5.269862 (-1.883905) 2.189070 / 4.565676 (-2.376607) 0.158443 / 0.424275 (-0.265832) 0.015650 / 0.007607 (0.008043) 0.860163 / 0.226044 (0.634118) 8.725464 / 2.268929 (6.456535) 3.870916 / 55.444624 (-51.573709) 3.246533 / 6.876477 (-3.629944) 3.497942 / 2.142072 (1.355869) 1.556252 / 4.805227 (-3.248976) 0.278982 / 6.500664 (-6.221682) 0.092569 / 0.075469 (0.017100)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 1.826874 / 1.841788 (-0.014913) 20.119990 / 8.074308 (12.045681) 23.380837 / 10.191392 (13.189445) 0.247117 / 0.680424 (-0.433307) 0.031790 / 0.534201 (-0.502411) 0.581225 / 0.579283 (0.001942) 0.644015 / 0.434364 (0.209651) 0.645386 / 0.540337 (0.105048) 0.788338 / 1.386936 (-0.598599)

Please sign in to comment.