Skip to content

Commit

Permalink
Set load_from_disk path type as PathLike (#7081)
Browse files Browse the repository at this point in the history
* Set path type as PathLike in load_from_disk

* Update docstrings

* Update tests

* Update save_to_disk docstrings
  • Loading branch information
albertvillanova committed Aug 13, 2024
1 parent dd18570 commit e84ca7c
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 15 deletions.
6 changes: 3 additions & 3 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1460,7 +1460,7 @@ def save_to_disk(
If you want to store paths or urls, please use the Value("string") type.
Args:
dataset_path (`str`):
dataset_path (`path-like`):
Path (e.g. `dataset/train`) or remote URI (e.g. `s3://my-bucket/dataset/train`)
of the dataset directory where the dataset will be saved to.
fs (`fsspec.spec.AbstractFileSystem`, *optional*):
Expand Down Expand Up @@ -1660,7 +1660,7 @@ def _build_local_temp_path(uri_or_path: str) -> Path:

@staticmethod
def load_from_disk(
dataset_path: str,
dataset_path: PathLike,
fs="deprecated",
keep_in_memory: Optional[bool] = None,
storage_options: Optional[dict] = None,
Expand All @@ -1670,7 +1670,7 @@ def load_from_disk(
filesystem using any implementation of `fsspec.spec.AbstractFileSystem`.
Args:
dataset_path (`str`):
dataset_path (`path-like`):
Path (e.g. `"dataset/train"`) or remote URI (e.g. `"s3//my-bucket/dataset/train"`)
of the dataset directory where the dataset will be loaded from.
fs (`fsspec.spec.AbstractFileSystem`, *optional*):
Expand Down
9 changes: 4 additions & 5 deletions src/datasets/dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,10 +1231,9 @@ def save_to_disk(
If you want to store paths or urls, please use the Value("string") type.
Args:
dataset_dict_path (`str`):
Path (e.g. `dataset/train`) or remote URI
(e.g. `s3://my-bucket/dataset/train`) of the dataset dict directory where the dataset dict will be
saved to.
dataset_dict_path (`path-like`):
Path (e.g. `dataset/train`) or remote URI (e.g. `s3://my-bucket/dataset/train`)
of the dataset dict directory where the dataset dict will be saved to.
fs (`fsspec.spec.AbstractFileSystem`, *optional*):
Instance of the remote filesystem where the dataset will be saved to.
Expand Down Expand Up @@ -1314,7 +1313,7 @@ def load_from_disk(
Load a dataset that was previously saved using [`save_to_disk`] from a filesystem using `fsspec.spec.AbstractFileSystem`.
Args:
dataset_dict_path (`str`):
dataset_dict_path (`path-like`):
Path (e.g. `"dataset/train"`) or remote URI (e.g. `"s3//my-bucket/dataset/train"`)
of the dataset dict directory where the dataset dict will be loaded from.
fs (`fsspec.spec.AbstractFileSystem`, *optional*):
Expand Down
12 changes: 8 additions & 4 deletions src/datasets/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
from .utils.logging import get_logger
from .utils.metadata import MetadataConfigs
from .utils.py_utils import get_imports, lock_importable_file
from .utils.typing import PathLike
from .utils.version import Version


Expand Down Expand Up @@ -2648,16 +2649,19 @@ def load_dataset(


def load_from_disk(
dataset_path: str, fs="deprecated", keep_in_memory: Optional[bool] = None, storage_options: Optional[dict] = None
dataset_path: PathLike,
fs="deprecated",
keep_in_memory: Optional[bool] = None,
storage_options: Optional[dict] = None,
) -> Union[Dataset, DatasetDict]:
"""
Loads a dataset that was previously saved using [`~Dataset.save_to_disk`] from a dataset directory, or
from a filesystem using any implementation of `fsspec.spec.AbstractFileSystem`.
Args:
dataset_path (`str`):
Path (e.g. `"dataset/train"`) or remote URI (e.g.
`"s3://my-bucket/dataset/train"`) of the [`Dataset`] or [`DatasetDict`] directory where the dataset will be
dataset_path (`path-like`):
Path (e.g. `"dataset/train"`) or remote URI (e.g. `"s3://my-bucket/dataset/train"`)
of the [`Dataset`] or [`DatasetDict`] directory where the dataset/dataset-dict will be
loaded from.
fs (`~filesystems.S3FileSystem` or `fsspec.spec.AbstractFileSystem`, *optional*):
Instance of the remote filesystem used to download the files from.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4119,7 +4119,7 @@ def test_dummy_dataset_serialize_fs(dataset, mockfs):
dataset.save_to_disk(dataset_path, storage_options=mockfs.storage_options)
assert mockfs.isdir(dataset_path)
assert mockfs.glob(dataset_path + "/*")
reloaded = dataset.load_from_disk(dataset_path, storage_options=mockfs.storage_options)
reloaded = Dataset.load_from_disk(dataset_path, storage_options=mockfs.storage_options)
assert len(reloaded) == len(dataset)
assert reloaded.features == dataset.features
assert reloaded.to_dict() == dataset.to_dict()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ def test_dummy_datasetdict_serialize_fs(mockfs):
dataset_dict.save_to_disk(dataset_path, storage_options=mockfs.storage_options)
assert mockfs.isdir(dataset_path)
assert mockfs.glob(dataset_path + "/*")
reloaded = dataset_dict.load_from_disk(dataset_path, storage_options=mockfs.storage_options)
reloaded = DatasetDict.load_from_disk(dataset_path, storage_options=mockfs.storage_options)
assert list(reloaded) == list(dataset_dict)
for k in dataset_dict:
assert reloaded[k].features == dataset_dict[k].features
Expand Down
2 changes: 1 addition & 1 deletion tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -1680,7 +1680,7 @@ def test_load_from_disk_with_default_in_memory(
expected_in_memory = False

dset = load_dataset(dataset_loading_script_dir, data_dir=data_dir, keep_in_memory=True, trust_remote_code=True)
dataset_path = os.path.join(tmp_path, "saved_dataset")
dataset_path = tmp_path / "saved_dataset"
dset.save_to_disk(dataset_path)

with assert_arrow_memory_increases() if expected_in_memory else assert_arrow_memory_doesnt_increase():
Expand Down

0 comments on commit e84ca7c

Please sign in to comment.