diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 84e005f3914..37de71824b6 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -494,7 +494,7 @@ def _check_legacy_cache2(self, dataset_module: "DatasetModule") -> Optional[str] and not is_remote_url(self._cache_dir_root) and not (set(self.config_kwargs) - {"data_files", "data_dir"}) ): - from .packaged_modules import _PACKAGED_DATASETS_MODULES + from .packaged_modules import _PACKAGED_DATASETS_MODULES_2_15_HASHES from .utils._dill import Pickler def update_hash_with_config_parameters(hash: str, config_parameters: dict) -> str: @@ -516,7 +516,7 @@ def update_hash_with_config_parameters(hash: str, config_parameters: dict) -> st namespace = self.repo_id.split("/")[0] if self.repo_id and self.repo_id.count("/") > 0 else None with patch.object(Pickler, "_legacy_no_dict_keys_sorting", True): config_id = self.config.name + "-" + Hasher.hash({"data_files": self.config.data_files}) - hash = _PACKAGED_DATASETS_MODULES.get(self.name, "missing")[1] + hash = _PACKAGED_DATASETS_MODULES_2_15_HASHES.get(self.name, "missing") if ( dataset_module.builder_configs_parameters.metadata_configs and self.config.name in dataset_module.builder_configs_parameters.metadata_configs diff --git a/src/datasets/packaged_modules/__init__.py b/src/datasets/packaged_modules/__init__.py index 984dc0f03a3..3513f9ae59e 100644 --- a/src/datasets/packaged_modules/__init__.py +++ b/src/datasets/packaged_modules/__init__.py @@ -43,6 +43,18 @@ def _hash_python_lines(lines: List[str]) -> str: "webdataset": (webdataset.__name__, _hash_python_lines(inspect.getsource(webdataset).splitlines())), } +# hashes of the packaged modules as of version 2.15, used to look up legacy caches +_PACKAGED_DATASETS_MODULES_2_15_HASHES = { + "csv": "eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d", + "json": "8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96", + "pandas": "3ac4ffc4563c796122ef66899b9485a3f1a977553e2d2a8a318c72b8cc6f2202", +
"parquet": "ca31c69184d9832faed373922c2acccec0b13a0bb5bbbe19371385c3ff26f1d1", + "arrow": "74f69db2c14c2860059d39860b1f400a03d11bf7fb5a8258ca38c501c878c137", + "text": "c4a140d10f020282918b5dd1b8a49f0104729c6177f60a6b49ec2a365ec69f34", + "imagefolder": "7b7ce5247a942be131d49ad4f3de5866083399a0f250901bd8dc202f8c5f7ce5", + "audiofolder": "d3c1655c66c8f72e4efb5c79e952975fa6e2ce538473a6890241ddbddee9071c", +} + # Used to infer the module to use based on the data files extensions _EXTENSION_TO_MODULE: Dict[str, Tuple[str, dict]] = { ".csv": ("csv", {}), diff --git a/src/datasets/packaged_modules/arrow/arrow.py b/src/datasets/packaged_modules/arrow/arrow.py index e2f7108dc3e..223a74602c3 100644 --- a/src/datasets/packaged_modules/arrow/arrow.py +++ b/src/datasets/packaged_modules/arrow/arrow.py @@ -17,6 +17,9 @@ class ArrowConfig(datasets.BuilderConfig): features: Optional[datasets.Features] = None + def __post_init__(self): + super().__post_init__() + class Arrow(datasets.ArrowBasedBuilder): BUILDER_CONFIG_CLASS = ArrowConfig diff --git a/src/datasets/packaged_modules/audiofolder/audiofolder.py b/src/datasets/packaged_modules/audiofolder/audiofolder.py index 02f9313cf9a..338fa09d755 100644 --- a/src/datasets/packaged_modules/audiofolder/audiofolder.py +++ b/src/datasets/packaged_modules/audiofolder/audiofolder.py @@ -15,6 +15,9 @@ class AudioFolderConfig(folder_based_builder.FolderBasedBuilderConfig): drop_labels: bool = None drop_metadata: bool = None + def __post_init__(self): + super().__post_init__() + class AudioFolder(folder_based_builder.FolderBasedBuilder): BASE_FEATURE = datasets.Audio diff --git a/src/datasets/packaged_modules/csv/csv.py b/src/datasets/packaged_modules/csv/csv.py index e877ae440b7..16f004aa3b4 100644 --- a/src/datasets/packaged_modules/csv/csv.py +++ b/src/datasets/packaged_modules/csv/csv.py @@ -68,6 +68,7 @@ class CsvConfig(datasets.BuilderConfig): date_format: Optional[str] = None def __post_init__(self): + super().__post_init__() if 
self.delimiter is not None: self.sep = self.delimiter if self.column_names is not None: diff --git a/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py b/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py index 24c32a746e8..7b71dc407ae 100644 --- a/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py +++ b/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py @@ -28,6 +28,9 @@ class FolderBasedBuilderConfig(datasets.BuilderConfig): drop_labels: bool = None drop_metadata: bool = None + def __post_init__(self): + super().__post_init__() + class FolderBasedBuilder(datasets.GeneratorBasedBuilder): """ diff --git a/src/datasets/packaged_modules/generator/generator.py b/src/datasets/packaged_modules/generator/generator.py index 1efa721b159..336942f2edc 100644 --- a/src/datasets/packaged_modules/generator/generator.py +++ b/src/datasets/packaged_modules/generator/generator.py @@ -11,7 +11,9 @@ class GeneratorConfig(datasets.BuilderConfig): features: Optional[datasets.Features] = None def __post_init__(self): - assert self.generator is not None, "generator must be specified" + super().__post_init__() + if self.generator is None: + raise ValueError("generator must be specified") if self.gen_kwargs is None: self.gen_kwargs = {} diff --git a/src/datasets/packaged_modules/imagefolder/imagefolder.py b/src/datasets/packaged_modules/imagefolder/imagefolder.py index bd2dd0d419a..16fbcd005d4 100644 --- a/src/datasets/packaged_modules/imagefolder/imagefolder.py +++ b/src/datasets/packaged_modules/imagefolder/imagefolder.py @@ -15,6 +15,9 @@ class ImageFolderConfig(folder_based_builder.FolderBasedBuilderConfig): drop_labels: bool = None drop_metadata: bool = None + def __post_init__(self): + super().__post_init__() + class ImageFolder(folder_based_builder.FolderBasedBuilder): BASE_FEATURE = datasets.Image diff --git a/src/datasets/packaged_modules/json/json.py 
b/src/datasets/packaged_modules/json/json.py index 0a9a5b71949..40539b625c9 100644 --- a/src/datasets/packaged_modules/json/json.py +++ b/src/datasets/packaged_modules/json/json.py @@ -44,6 +44,9 @@ class JsonConfig(datasets.BuilderConfig): chunksize: int = 10 << 20 # 10MB newlines_in_values: Optional[bool] = None + def __post_init__(self): + super().__post_init__() + class Json(datasets.ArrowBasedBuilder): BUILDER_CONFIG_CLASS = JsonConfig diff --git a/src/datasets/packaged_modules/pandas/pandas.py b/src/datasets/packaged_modules/pandas/pandas.py index c17f389945e..d1eb50d33c8 100644 --- a/src/datasets/packaged_modules/pandas/pandas.py +++ b/src/datasets/packaged_modules/pandas/pandas.py @@ -16,6 +16,9 @@ class PandasConfig(datasets.BuilderConfig): features: Optional[datasets.Features] = None + def __post_init__(self): + super().__post_init__() + class Pandas(datasets.ArrowBasedBuilder): BUILDER_CONFIG_CLASS = PandasConfig diff --git a/src/datasets/packaged_modules/parquet/parquet.py b/src/datasets/packaged_modules/parquet/parquet.py index 9ecbd4606fa..f6ec2e06cc0 100644 --- a/src/datasets/packaged_modules/parquet/parquet.py +++ b/src/datasets/packaged_modules/parquet/parquet.py @@ -20,6 +20,9 @@ class ParquetConfig(datasets.BuilderConfig): columns: Optional[List[str]] = None features: Optional[datasets.Features] = None + def __post_init__(self): + super().__post_init__() + class Parquet(datasets.ArrowBasedBuilder): BUILDER_CONFIG_CLASS = ParquetConfig diff --git a/src/datasets/packaged_modules/spark/spark.py b/src/datasets/packaged_modules/spark/spark.py index 5c0c836a41e..c21cb3dd981 100644 --- a/src/datasets/packaged_modules/spark/spark.py +++ b/src/datasets/packaged_modules/spark/spark.py @@ -33,6 +33,9 @@ class SparkConfig(datasets.BuilderConfig): features: Optional[datasets.Features] = None + def __post_init__(self): + super().__post_init__() + def _reorder_dataframe_by_partition(df: "pyspark.sql.DataFrame", new_partition_order: List[int]): df_combined = 
df.select("*").where(f"part_id = {new_partition_order[0]}") diff --git a/src/datasets/packaged_modules/sql/sql.py b/src/datasets/packaged_modules/sql/sql.py index b0791ba8859..152a8dc2089 100644 --- a/src/datasets/packaged_modules/sql/sql.py +++ b/src/datasets/packaged_modules/sql/sql.py @@ -35,6 +35,7 @@ class SqlConfig(datasets.BuilderConfig): features: Optional[datasets.Features] = None def __post_init__(self): + super().__post_init__() if self.sql is None: raise ValueError("sql must be specified") if self.con is None: diff --git a/src/datasets/packaged_modules/text/text.py b/src/datasets/packaged_modules/text/text.py index 13669bb5b17..e0e0d7fedb9 100644 --- a/src/datasets/packaged_modules/text/text.py +++ b/src/datasets/packaged_modules/text/text.py @@ -27,6 +27,7 @@ class TextConfig(datasets.BuilderConfig): sample_by: str = "line" def __post_init__(self, errors): + super().__post_init__() if errors != "deprecated": warnings.warn( "'errors' was deprecated in favor of 'encoding_errors' in version 2.14.0 and will be removed in 3.0.0.\n" diff --git a/tests/packaged_modules/test_arrow.py b/tests/packaged_modules/test_arrow.py new file mode 100644 index 00000000000..dda3720efe3 --- /dev/null +++ b/tests/packaged_modules/test_arrow.py @@ -0,0 +1,16 @@ +import pytest + +from datasets.builder import InvalidConfigName +from datasets.data_files import DataFilesList +from datasets.packaged_modules.arrow.arrow import ArrowConfig + + +def test_config_raises_when_invalid_name() -> None: + with pytest.raises(InvalidConfigName, match="Bad characters"): + _ = ArrowConfig(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])]) +def test_config_raises_when_invalid_data_files(data_files) -> None: + with pytest.raises(ValueError, match="Expected a DataFilesDict"): + _ = ArrowConfig(name="name", data_files=data_files) diff --git a/tests/packaged_modules/test_audiofolder.py 
b/tests/packaged_modules/test_audiofolder.py index 712e6aeac4f..3351fccf604 100644 --- a/tests/packaged_modules/test_audiofolder.py +++ b/tests/packaged_modules/test_audiofolder.py @@ -7,9 +7,10 @@ import soundfile as sf from datasets import Audio, ClassLabel, Features, Value -from datasets.data_files import DataFilesDict, get_data_patterns +from datasets.builder import InvalidConfigName +from datasets.data_files import DataFilesDict, DataFilesList, get_data_patterns from datasets.download.streaming_download_manager import StreamingDownloadManager -from datasets.packaged_modules.audiofolder.audiofolder import AudioFolder +from datasets.packaged_modules.audiofolder.audiofolder import AudioFolder, AudioFolderConfig from ..utils import require_sndfile @@ -230,6 +231,17 @@ def data_files_with_zip_archives(tmp_path, audio_file): return data_files_with_zip_archives +def test_config_raises_when_invalid_name() -> None: + with pytest.raises(InvalidConfigName, match="Bad characters"): + _ = AudioFolderConfig(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])]) +def test_config_raises_when_invalid_data_files(data_files) -> None: + with pytest.raises(ValueError, match="Expected a DataFilesDict"): + _ = AudioFolderConfig(name="name", data_files=data_files) + + @require_sndfile # check that labels are inferred correctly from dir names def test_generate_examples_with_labels(data_files_with_labels_no_metadata, cache_dir): diff --git a/tests/packaged_modules/test_csv.py b/tests/packaged_modules/test_csv.py index 6cfa5e4ca23..e85dc1e3b09 100644 --- a/tests/packaged_modules/test_csv.py +++ b/tests/packaged_modules/test_csv.py @@ -5,7 +5,9 @@ import pytest from datasets import ClassLabel, Features, Image -from datasets.packaged_modules.csv.csv import Csv +from datasets.builder import InvalidConfigName +from datasets.data_files import DataFilesList +from datasets.packaged_modules.csv.csv import 
Csv, CsvConfig from ..utils import require_pil @@ -86,6 +88,17 @@ def csv_file_with_int_list(tmp_path): return str(filename) +def test_config_raises_when_invalid_name() -> None: + with pytest.raises(InvalidConfigName, match="Bad characters"): + _ = CsvConfig(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])]) +def test_config_raises_when_invalid_data_files(data_files) -> None: + with pytest.raises(ValueError, match="Expected a DataFilesDict"): + _ = CsvConfig(name="name", data_files=data_files) + + def test_csv_generate_tables_raises_error_with_malformed_csv(csv_file, malformed_csv_file, caplog): csv = Csv() generator = csv._generate_tables([[csv_file, malformed_csv_file]]) diff --git a/tests/packaged_modules/test_folder_based_builder.py b/tests/packaged_modules/test_folder_based_builder.py index c6aad5ded09..3623c4b1680 100644 --- a/tests/packaged_modules/test_folder_based_builder.py +++ b/tests/packaged_modules/test_folder_based_builder.py @@ -5,7 +5,8 @@ import pytest from datasets import ClassLabel, DownloadManager, Features, Value -from datasets.data_files import DataFilesDict, get_data_patterns +from datasets.builder import InvalidConfigName +from datasets.data_files import DataFilesDict, DataFilesList, get_data_patterns from datasets.download.streaming_download_manager import StreamingDownloadManager from datasets.packaged_modules.folder_based_builder.folder_based_builder import ( FolderBasedBuilder, @@ -265,6 +266,17 @@ def data_files_with_zip_archives(tmp_path, auto_text_file): return data_files_with_zip_archives +def test_config_raises_when_invalid_name() -> None: + with pytest.raises(InvalidConfigName, match="Bad characters"): + _ = FolderBasedBuilderConfig(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])]) +def test_config_raises_when_invalid_data_files(data_files) -> 
None: + with pytest.raises(ValueError, match="Expected a DataFilesDict"): + _ = FolderBasedBuilderConfig(name="name", data_files=data_files) + + def test_inferring_labels_from_data_dirs(data_files_with_labels_no_metadata, cache_dir): autofolder = DummyFolderBasedBuilder( data_files=data_files_with_labels_no_metadata, cache_dir=cache_dir, drop_labels=False diff --git a/tests/packaged_modules/test_imagefolder.py b/tests/packaged_modules/test_imagefolder.py index 3be9195d6aa..835d3a7db0c 100644 --- a/tests/packaged_modules/test_imagefolder.py +++ b/tests/packaged_modules/test_imagefolder.py @@ -5,9 +5,10 @@ import pytest from datasets import ClassLabel, Features, Image, Value -from datasets.data_files import DataFilesDict, get_data_patterns +from datasets.builder import InvalidConfigName +from datasets.data_files import DataFilesDict, DataFilesList, get_data_patterns from datasets.download.streaming_download_manager import StreamingDownloadManager -from datasets.packaged_modules.imagefolder.imagefolder import ImageFolder +from datasets.packaged_modules.imagefolder.imagefolder import ImageFolder, ImageFolderConfig from ..utils import require_pil @@ -239,6 +240,17 @@ def data_files_with_zip_archives(tmp_path, image_file): return data_files_with_zip_archives +def test_config_raises_when_invalid_name() -> None: + with pytest.raises(InvalidConfigName, match="Bad characters"): + _ = ImageFolderConfig(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])]) +def test_config_raises_when_invalid_data_files(data_files) -> None: + with pytest.raises(ValueError, match="Expected a DataFilesDict"): + _ = ImageFolderConfig(name="name", data_files=data_files) + + @require_pil # check that labels are inferred correctly from dir names def test_generate_examples_with_labels(data_files_with_labels_no_metadata, cache_dir): diff --git a/tests/packaged_modules/test_json.py 
b/tests/packaged_modules/test_json.py index a3259991078..07bbba7adec 100644 --- a/tests/packaged_modules/test_json.py +++ b/tests/packaged_modules/test_json.py @@ -4,7 +4,9 @@ import pytest from datasets import Features, Value -from datasets.packaged_modules.json.json import Json +from datasets.builder import InvalidConfigName +from datasets.data_files import DataFilesList +from datasets.packaged_modules.json.json import Json, JsonConfig @pytest.fixture @@ -171,6 +173,17 @@ def json_file_with_list_of_dicts_with_sorted_columns_field(tmp_path): return str(path) +def test_config_raises_when_invalid_name() -> None: + with pytest.raises(InvalidConfigName, match="Bad characters"): + _ = JsonConfig(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])]) +def test_config_raises_when_invalid_data_files(data_files) -> None: + with pytest.raises(ValueError, match="Expected a DataFilesDict"): + _ = JsonConfig(name="name", data_files=data_files) + + @pytest.mark.parametrize( "file_fixture, config_kwargs", [ diff --git a/tests/packaged_modules/test_pandas.py b/tests/packaged_modules/test_pandas.py new file mode 100644 index 00000000000..60b3bb22107 --- /dev/null +++ b/tests/packaged_modules/test_pandas.py @@ -0,0 +1,16 @@ +import pytest + +from datasets.builder import InvalidConfigName +from datasets.data_files import DataFilesList +from datasets.packaged_modules.pandas.pandas import PandasConfig + + +def test_config_raises_when_invalid_name() -> None: + with pytest.raises(InvalidConfigName, match="Bad characters"): + _ = PandasConfig(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])]) +def test_config_raises_when_invalid_data_files(data_files) -> None: + with pytest.raises(ValueError, match="Expected a DataFilesDict"): + _ = PandasConfig(name="name", data_files=data_files) diff --git 
a/tests/packaged_modules/test_parquet.py b/tests/packaged_modules/test_parquet.py new file mode 100644 index 00000000000..b5c1808d8f9 --- /dev/null +++ b/tests/packaged_modules/test_parquet.py @@ -0,0 +1,16 @@ +import pytest + +from datasets.builder import InvalidConfigName +from datasets.data_files import DataFilesList +from datasets.packaged_modules.parquet.parquet import ParquetConfig + + +def test_config_raises_when_invalid_name() -> None: + with pytest.raises(InvalidConfigName, match="Bad characters"): + _ = ParquetConfig(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])]) +def test_config_raises_when_invalid_data_files(data_files) -> None: + with pytest.raises(ValueError, match="Expected a DataFilesDict"): + _ = ParquetConfig(name="name", data_files=data_files) diff --git a/tests/packaged_modules/test_spark.py b/tests/packaged_modules/test_spark.py index 7e11c02113f..c91bdd571ea 100644 --- a/tests/packaged_modules/test_spark.py +++ b/tests/packaged_modules/test_spark.py @@ -1,9 +1,13 @@ from unittest.mock import patch import pyspark +import pytest +from datasets.builder import InvalidConfigName +from datasets.data_files import DataFilesList from datasets.packaged_modules.spark.spark import ( Spark, + SparkConfig, SparkExamplesIterable, _generate_iterable_examples, ) @@ -23,6 +27,17 @@ def _get_expected_row_ids_and_row_dicts_for_partition_order(df, partition_order) return expected_row_ids_and_row_dicts +def test_config_raises_when_invalid_name() -> None: + with pytest.raises(InvalidConfigName, match="Bad characters"): + _ = SparkConfig(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])]) +def test_config_raises_when_invalid_data_files(data_files) -> None: + with pytest.raises(ValueError, match="Expected a DataFilesDict"): + _ = SparkConfig(name="name", data_files=data_files) + + 
@require_not_windows @require_dill_gt_0_3_2 def test_repartition_df_if_needed(): diff --git a/tests/packaged_modules/test_sql.py b/tests/packaged_modules/test_sql.py new file mode 100644 index 00000000000..e745cb03d2e --- /dev/null +++ b/tests/packaged_modules/test_sql.py @@ -0,0 +1,16 @@ +import pytest + +from datasets.builder import InvalidConfigName +from datasets.data_files import DataFilesList +from datasets.packaged_modules.sql.sql import SqlConfig + + +def test_config_raises_when_invalid_name() -> None: + with pytest.raises(InvalidConfigName, match="Bad characters"): + _ = SqlConfig(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])]) +def test_config_raises_when_invalid_data_files(data_files) -> None: + with pytest.raises(ValueError, match="Expected a DataFilesDict"): + _ = SqlConfig(name="name", data_files=data_files) diff --git a/tests/packaged_modules/test_text.py b/tests/packaged_modules/test_text.py index 0d1b3f3b5a4..a21b3e223d9 100644 --- a/tests/packaged_modules/test_text.py +++ b/tests/packaged_modules/test_text.py @@ -4,7 +4,9 @@ import pytest from datasets import Features, Image -from datasets.packaged_modules.text.text import Text +from datasets.builder import InvalidConfigName +from datasets.data_files import DataFilesList +from datasets.packaged_modules.text.text import Text, TextConfig from ..utils import require_pil @@ -39,6 +41,17 @@ def text_file_with_image(tmp_path, image_file): return str(filename) +def test_config_raises_when_invalid_name() -> None: + with pytest.raises(InvalidConfigName, match="Bad characters"): + _ = TextConfig(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])]) +def test_config_raises_when_invalid_data_files(data_files) -> None: + with pytest.raises(ValueError, match="Expected a DataFilesDict"): + _ = TextConfig(name="name", 
data_files=data_files) + + @pytest.mark.parametrize("keep_linebreaks", [True, False]) def test_text_linebreaks(text_file, keep_linebreaks): with open(text_file, encoding="utf-8") as f: diff --git a/tests/test_builder.py b/tests/test_builder.py index 81966044fc3..6698a79cbf8 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -17,7 +17,15 @@ from datasets.arrow_dataset import Dataset from datasets.arrow_reader import DatasetNotOnHfGcsError from datasets.arrow_writer import ArrowWriter -from datasets.builder import ArrowBasedBuilder, BeamBasedBuilder, BuilderConfig, DatasetBuilder, GeneratorBasedBuilder +from datasets.builder import ( + ArrowBasedBuilder, + BeamBasedBuilder, + BuilderConfig, + DatasetBuilder, + GeneratorBasedBuilder, + InvalidConfigName, +) +from datasets.data_files import DataFilesList from datasets.dataset_dict import DatasetDict, IterableDatasetDict from datasets.download.download_manager import DownloadMode from datasets.features import Features, Value @@ -836,6 +844,17 @@ def test_cache_dir_for_configured_builder(self): self.assertNotEqual(builder.cache_dir, other_builder.cache_dir) +def test_config_raises_when_invalid_name() -> None: + with pytest.raises(InvalidConfigName, match="Bad characters"): + _ = BuilderConfig(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])]) +def test_config_raises_when_invalid_data_files(data_files) -> None: + with pytest.raises(ValueError, match="Expected a DataFilesDict"): + _ = BuilderConfig(name="name", data_files=data_files) + + def test_arrow_based_download_and_prepare(tmp_path): builder = DummyArrowBasedBuilder(cache_dir=tmp_path) builder.download_and_prepare()