diff --git a/libs/libcommon/src/libcommon/constants.py b/libs/libcommon/src/libcommon/constants.py
index d889497271..a39b64bc74 100644
--- a/libs/libcommon/src/libcommon/constants.py
+++ b/libs/libcommon/src/libcommon/constants.py
@@ -18,7 +18,7 @@
 PROCESSING_STEP_DATASET_SIZE_VERSION = 2
 PROCESSING_STEP_DATASET_SPLIT_NAMES_FROM_DATASET_INFO_VERSION = 2
 PROCESSING_STEP_DATASET_SPLIT_NAMES_FROM_STREAMING_VERSION = 2
-PROCESSING_STEP_PARQUET_AND_DATASET_INFO_VERSION = 1
+PROCESSING_STEP_PARQUET_AND_DATASET_INFO_VERSION = 2
 PROCESSING_STEP_SPLIT_FIRST_ROWS_FROM_PARQUET_VERSION = 2
 PROCESSING_STEP_CONFIG_PARQUET_AND_INFO_VERSION = 2
 PROCESSING_STEP_SPLIT_FIRST_ROWS_FROM_STREAMING_VERSION = 3
@@ -26,4 +26,6 @@
 PROCESSING_STEP_SPLIT_NAMES_FROM_STREAMING_VERSION = 3
 PROCESSING_STEP_DATASET_SPLIT_NAMES_VERSION = 2
 
+PROCESSING_STEP_PARQUET_AND_DATASET_INFO_ROW_GROUP_SIZE_FOR_AUDIO_DATASETS = 100
+PROCESSING_STEP_PARQUET_AND_DATASET_INFO_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS = 100
 PARQUET_REVISION = "refs/convert/parquet"
diff --git a/services/worker/src/worker/job_runners/parquet_and_dataset_info.py b/services/worker/src/worker/job_runners/parquet_and_dataset_info.py
index 6702deffd1..59ab32731f 100644
--- a/services/worker/src/worker/job_runners/parquet_and_dataset_info.py
+++ b/services/worker/src/worker/job_runners/parquet_and_dataset_info.py
@@ -14,6 +14,7 @@
 
 import datasets
 import datasets.config
+import datasets.info
 import numpy as np
 import requests
 from datasets import (
@@ -39,7 +40,11 @@
 )
 from huggingface_hub.hf_api import DatasetInfo, HfApi, RepoFile
 from huggingface_hub.utils._errors import RepositoryNotFoundError, RevisionNotFoundError
-from libcommon.constants import PROCESSING_STEP_PARQUET_AND_DATASET_INFO_VERSION
+from libcommon.constants import (
+    PROCESSING_STEP_PARQUET_AND_DATASET_INFO_ROW_GROUP_SIZE_FOR_AUDIO_DATASETS,
+    PROCESSING_STEP_PARQUET_AND_DATASET_INFO_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS,
+    PROCESSING_STEP_PARQUET_AND_DATASET_INFO_VERSION,
+)
 from libcommon.dataset import DatasetNotFoundError, ask_access
 from libcommon.processing_graph import ProcessingStep
 from libcommon.queue import JobInfo
@@ -634,6 +639,30 @@ def raise_if_too_big_from_external_data_files(
         ) from error
 
 
+def get_writer_batch_size(ds_config_info: datasets.info.DatasetInfo) -> Optional[int]:
+    """
+    Get the writer_batch_size that defines the maximum row group size in the parquet files.
+    The default in `datasets` is 1,000 but we lower it to 100 for image and audio datasets.
+    This allows optimizing random access to the parquet files, since accessing one row
+    requires reading its entire row group.
+
+    Args:
+        ds_config_info (`datasets.info.DatasetInfo`):
+            Dataset info from `datasets`.
+
+    Returns:
+        writer_batch_size (`Optional[int]`):
+            Writer batch size to pass to a dataset builder.
+            If `None`, then it will use the `datasets` default.
+    """
+    if "Audio(" in str(ds_config_info.features):
+        return PROCESSING_STEP_PARQUET_AND_DATASET_INFO_ROW_GROUP_SIZE_FOR_AUDIO_DATASETS
+    elif "Image(" in str(ds_config_info.features):
+        return PROCESSING_STEP_PARQUET_AND_DATASET_INFO_ROW_GROUP_SIZE_FOR_IMAGE_DATASETS
+    else:
+        return None
+
+
 def compute_parquet_and_dataset_info_response(
     dataset: str,
     hf_endpoint: str,
@@ -766,6 +795,11 @@ def compute_parquet_and_dataset_info_response(
             use_auth_token=hf_token,
             download_config=download_config,
         )
+        writer_batch_size = get_writer_batch_size(builder.info)
+        if writer_batch_size is not None and (
+            builder._writer_batch_size is None or builder._writer_batch_size > writer_batch_size
+        ):
+            builder._writer_batch_size = writer_batch_size
         raise_if_too_big_from_external_data_files(
             builder=builder,
             max_dataset_size=max_dataset_size,
diff --git a/services/worker/tests/job_runners/test_parquet_and_dataset_info.py b/services/worker/tests/job_runners/test_parquet_and_dataset_info.py
index 931ea63b67..bb126a392e 100644
--- a/services/worker/tests/job_runners/test_parquet_and_dataset_info.py
+++ b/services/worker/tests/job_runners/test_parquet_and_dataset_info.py
@@ -6,9 +6,11 @@
 from typing import Any, Callable, Iterator, List, Optional
 
 import datasets.builder
+import datasets.info
 import pandas as pd
 import pytest
 import requests
+from datasets import Features, Image, Value
 from libcommon.exceptions import CustomError
 from libcommon.processing_graph import ProcessingStep
 from libcommon.queue import Priority
@@ -24,6 +26,7 @@
     DatasetWithTooManyExternalFilesError,
     ParquetAndDatasetInfoJobRunner,
     get_dataset_info_or_raise,
+    get_writer_batch_size,
     parse_repo_filename,
     raise_if_blocked,
     raise_if_not_supported,
@@ -487,3 +490,16 @@ def test_parse_repo_filename(filename: str, split: str, config: str, raises: boo
         parse_repo_filename(filename)
     else:
         assert parse_repo_filename(filename) == (config, split)
+
+
+@pytest.mark.parametrize(
+    "ds_info, with_image",
+    [
+        (datasets.info.DatasetInfo(), False),
+        (datasets.info.DatasetInfo(features=Features({"text": Value("string")})), False),
+        (datasets.info.DatasetInfo(features=Features({"image": Image()})), True),
+        (datasets.info.DatasetInfo(features=Features({"nested": [{"image": Image()}]})), True),
+    ],
+)
+def test_get_writer_batch_size(ds_info: datasets.info.DatasetInfo, with_image: bool) -> None:
+    assert get_writer_batch_size(ds_info) == (100 if with_image else None)