Support streaming Beam datasets from HF GCS preprocessed data (#5689)
* Remove error when streaming BeamBasedBuilder

* Implement BeamBasedBuilder.as_streaming_dataset

* Test BeamBasedBuilder.as_streaming_dataset raises

* Remove draft comment

* Move local imports to module level

* Test as_streaming_dataset on wikipedia

* Remove unnecessary beam_runner from test

* Refactor tests

* Refactor BeamBasedBuilder

* Refactor test

* Refactor to support sharded files
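
For context (not part of the commit message): a minimal usage sketch of what this change enables, assuming the "20220301.frr" Wikipedia config has been preprocessed and mirrored on the HF GCS bucket:

from datasets import load_dataset

# Before this commit, streaming a Beam-based dataset raised 'Builder ... is not streamable'.
# Now the builder streams the preprocessed Arrow shards from HF GCS instead.
ds = load_dataset("wikipedia", "20220301.frr", streaming=True, split="train")
print(next(iter(ds)))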
albertvillanova authored Apr 12, 2023
1 parent f260793 commit ce06edf
Showing 3 changed files with 97 additions and 13 deletions.
63 changes: 59 additions & 4 deletions src/datasets/builder.py
@@ -32,6 +32,7 @@
from typing import Dict, Iterable, Mapping, Optional, Tuple, Union

import fsspec
import pyarrow as pa
from multiprocess import Pool
from tqdm.contrib.concurrent import thread_map

@@ -50,7 +51,7 @@
from .download.download_config import DownloadConfig
from .download.download_manager import DownloadManager, DownloadMode
from .download.mock_download_manager import MockDownloadManager
from .download.streaming_download_manager import StreamingDownloadManager
from .download.streaming_download_manager import StreamingDownloadManager, xopen
from .features import Features
from .filesystems import is_remote_filesystem
from .fingerprint import Hasher
@@ -1245,9 +1246,6 @@ def as_streaming_dataset(
        split: Optional[str] = None,
        base_path: Optional[str] = None,
    ) -> Union[Dict[str, IterableDataset], IterableDataset]:
        if not isinstance(self, (GeneratorBasedBuilder, ArrowBasedBuilder)):
            raise ValueError(f"Builder {self.name} is not streamable.")

        is_local = not is_remote_filesystem(self._fs)
        if not is_local:
            raise NotImplementedError(
@@ -2081,3 +2079,60 @@ def _build_pcollection(pipeline):

        # Add the PCollection to the pipeline
        _ = pipeline | split_name >> _build_pcollection()  # pylint: disable=no-value-for-parameter

    def as_streaming_dataset(
        self,
        split: Optional[str] = None,
    ) -> Union[Dict[str, IterableDataset], IterableDataset]:
        self._request_info_from_hf_gcs()
        datasets = {
            split.name: IterableDataset(self._get_examples_iterable_for_split(split), info=self.info, split=split.name)
            for split in self.info.splits.values()
        }
        if split:
            try:
                datasets = datasets[split]
            except KeyError:
                raise ValueError(f"Bad split: {split}. Available splits: {list(datasets)}")
        if isinstance(datasets, dict):
            datasets = IterableDatasetDict(datasets)
        return datasets

    def _get_examples_iterable_for_split(self, split: SplitInfo) -> ExamplesIterable:
        return ExamplesIterable(self._generate_examples_from_hf_gcs, {"split": split})

    def _generate_examples_from_hf_gcs(self, split: SplitInfo):
        if split.shard_lengths:
            num_shards = len(split.shard_lengths)
            remote_prepared_urls = [
                f"{self._remote_cache_dir_from_hf_gcs}/{self.name}-{split.name}-{shard_id:05d}-of-{num_shards:05d}.arrow"
                for shard_id in range(num_shards)
            ]
        else:
            remote_prepared_urls = [f"{self._remote_cache_dir_from_hf_gcs}/{self.name}-{split.name}.arrow"]
        key = 0
        for remote_prepared_url in remote_prepared_urls:
            with xopen(remote_prepared_url, "rb") as f:
                with pa.ipc.open_stream(f) as reader:
                    for record_batch in reader:
                        for record in record_batch.to_pylist():
                            yield key, record
                            key += 1

    def _request_info_from_hf_gcs(self):
        from .download.streaming_download_manager import xopen

        remote_dataset_info = f"{self._remote_cache_dir_from_hf_gcs}/{config.DATASET_INFO_FILENAME}"
        try:
            with xopen(remote_dataset_info) as f:
                import json

                _info = json.load(f)
        except FileNotFoundError as err:
            raise DatasetNotOnHfGcsError(err) from None
        self.info.update(DatasetInfo.from_dict(_info))

    @property
    def _remote_cache_dir_from_hf_gcs(self):
        relative_data_dir = self._relative_data_dir(with_hash=False)
        return HF_GCP_BASE_URL + "/" + relative_data_dir.replace(os.sep, "/")
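
For reference, a self-contained sketch of the read pattern implemented by _generate_examples_from_hf_gcs above, using plain fsspec in place of the library's internal xopen helper (the function name here is illustrative, not part of the commit):

import fsspec
import pyarrow as pa

def iter_examples_from_arrow_shards(shard_urls):
    # Mirror of the generator above: iterate each remote .arrow shard's Arrow IPC
    # record batches and yield (key, example) pairs, reading the stream incrementally.
    key = 0
    for url in shard_urls:
        with fsspec.open(url, "rb") as f:  # fsspec stands in for datasets' xopen
            with pa.ipc.open_stream(f) as reader:  # Arrow IPC stream reader
                for record_batch in reader:
                    for example in record_batch.to_pylist():
                        yield key, example
                        key += 1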
9 changes: 9 additions & 0 deletions tests/test_builder.py
@@ -15,6 +15,7 @@
from multiprocess.pool import Pool

from datasets.arrow_dataset import Dataset
from datasets.arrow_reader import DatasetNotOnHfGcsError
from datasets.arrow_writer import ArrowWriter
from datasets.builder import ArrowBasedBuilder, BeamBasedBuilder, BuilderConfig, DatasetBuilder, GeneratorBasedBuilder
from datasets.dataset_dict import DatasetDict, IterableDatasetDict
@@ -935,6 +936,14 @@ def test_builder_as_streaming_dataset(tmp_path):
    assert len(list(dset)) == 100


@require_beam
def test_beam_based_builder_as_streaming_dataset(tmp_path):
    builder = DummyBeamBasedBuilder(cache_dir=tmp_path)
    check_streaming(builder)
    with pytest.raises(DatasetNotOnHfGcsError):
        builder.as_streaming_dataset()


def _run_test_builder_streaming_works_in_subprocesses(builder):
    check_streaming(builder)
    dset = builder.as_streaming_dataset(split="train")
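
For context, a hedged sketch of the caller-side behavior exercised by the new test above: when the preprocessed Arrow shards are not mirrored on HF GCS, as_streaming_dataset raises DatasetNotOnHfGcsError and the dataset has to be prepared locally with an Apache Beam runner instead (the fallback call uses the documented beam_runner option and is illustrative, not part of the commit):

from datasets import load_dataset, load_dataset_builder
from datasets.arrow_reader import DatasetNotOnHfGcsError

builder = load_dataset_builder("wikipedia", "20220301.frr")
try:
    # Works only if the preprocessed shards for this config exist on HF GCS
    ds = builder.as_streaming_dataset(split="train")
except DatasetNotOnHfGcsError:
    # Nothing preprocessed on HF GCS: process the dataset locally with Apache Beam instead
    ds = load_dataset("wikipedia", "20220301.frr", beam_runner="DirectRunner", split="train")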
38 changes: 29 additions & 9 deletions tests/test_hf_gcp.py
@@ -8,6 +8,8 @@
from datasets import config
from datasets.arrow_reader import HF_GCP_BASE_URL
from datasets.builder import DatasetBuilder
from datasets.dataset_dict import IterableDatasetDict
from datasets.iterable_dataset import IterableDataset
from datasets.load import dataset_module_factory, import_main_class
from datasets.utils.file_utils import cached_path

@@ -62,28 +64,46 @@ def test_dataset_info_available(self, dataset, config_name):
                hash=dataset_module.hash,
            )

            dataset_info_url = os.path.join(
                HF_GCP_BASE_URL, builder_instance._relative_data_dir(with_hash=False), config.DATASET_INFO_FILENAME
            ).replace(os.sep, "/")
            dataset_info_url = "/".join(
                [
                    HF_GCP_BASE_URL,
                    builder_instance._relative_data_dir(with_hash=False).replace(os.sep, "/"),
                    config.DATASET_INFO_FILENAME,
                ]
            )
            datset_info_path = cached_path(dataset_info_url, cache_dir=tmp_dir)
            self.assertTrue(os.path.exists(datset_info_path))


@pytest.mark.integration
def test_wikipedia_frr(tmp_path_factory):
def test_as_dataset_from_hf_gcs(tmp_path_factory):
    tmp_dir = tmp_path_factory.mktemp("test_hf_gcp") / "test_wikipedia_simple"
    dataset_module = dataset_module_factory("wikipedia", cache_dir=tmp_dir)

    builder_cls = import_main_class(dataset_module.module_path, dataset=True)

    builder_cls = import_main_class(dataset_module.module_path)
    builder_instance: DatasetBuilder = builder_cls(
        cache_dir=tmp_dir,
        config_name="20220301.frr",
        hash=dataset_module.hash,
    )

    # use the HF cloud storage, not the original download_and_prepare that uses apache-beam
    builder_instance._download_and_prepare = None
    builder_instance.download_and_prepare()
    ds = builder_instance.as_dataset()
    assert ds is not None
    assert ds


@pytest.mark.integration
def test_as_streaming_dataset_from_hf_gcs(tmp_path):
    dataset_module = dataset_module_factory("wikipedia", cache_dir=tmp_path)
    builder_cls = import_main_class(dataset_module.module_path, dataset=True)
    builder_instance: DatasetBuilder = builder_cls(
        cache_dir=tmp_path,
        config_name="20220301.frr",
        hash=dataset_module.hash,
    )
    ds = builder_instance.as_streaming_dataset()
    assert ds
    assert isinstance(ds, IterableDatasetDict)
    assert "train" in ds
    assert isinstance(ds["train"], IterableDataset)
    assert next(iter(ds["train"]))
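
A short consumption sketch for the streamed dataset returned above (column access assumes the wikipedia schema with a "title" field; only a few examples are pulled over the network):

from itertools import islice

ds = builder_instance.as_streaming_dataset()
for example in islice(ds["train"], 3):
    print(example["title"])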
