Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Document that datasets support pathlib.Path #8321

Merged
merged 4 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions torchvision/datasets/_optical_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from .utils import _read_pfm, verify_str_arg
from .vision import VisionDataset


T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], Optional[np.ndarray]]
T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]

Expand All @@ -33,7 +32,7 @@ class FlowDataset(ABC, VisionDataset):
# and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
_has_builtin_flow_mask = False

def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:

super().__init__(root=root)
self.transforms = transforms
Expand Down Expand Up @@ -113,7 +112,7 @@ class Sintel(FlowDataset):
...

Args:
root (string): Root directory of the Sintel Dataset.
root (str or ``pathlib.Path``): Root directory of the Sintel Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
pass_name (string, optional): The pass to use, either "clean" (default), "final", or "both". See link above for
details on the different passes.
Expand All @@ -125,7 +124,7 @@ class Sintel(FlowDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "train",
pass_name: str = "clean",
transforms: Optional[Callable] = None,
Expand Down Expand Up @@ -183,15 +182,15 @@ class KittiFlow(FlowDataset):
flow_occ

Args:
root (string): Root directory of the KittiFlow Dataset.
root (str or ``pathlib.Path``): Root directory of the KittiFlow Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
"""

_has_builtin_flow_mask = True

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root=root, transforms=transforms)

verify_str_arg(split, "split", valid_values=("train", "test"))
Expand Down Expand Up @@ -248,15 +247,15 @@ class FlyingChairs(FlowDataset):


Args:
root (string): Root directory of the FlyingChairs Dataset.
root (str or ``pathlib.Path``): Root directory of the FlyingChairs Dataset.
split (string, optional): The dataset split, either "train" (default) or "val"
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
``valid_flow_mask`` is expected for consistency with other datasets which
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
"""

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root=root, transforms=transforms)

verify_str_arg(split, "split", valid_values=("train", "val"))
Expand Down Expand Up @@ -316,7 +315,7 @@ class FlyingThings3D(FlowDataset):
TRAIN

Args:
root (string): Root directory of the intel FlyingThings3D Dataset.
root (str or ``pathlib.Path``): Root directory of the intel FlyingThings3D Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
pass_name (string, optional): The pass to use, either "clean" (default) or "final" or "both". See link above for
details on the different passes.
Expand All @@ -329,7 +328,7 @@ class FlyingThings3D(FlowDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "train",
pass_name: str = "clean",
camera: str = "left",
Expand Down Expand Up @@ -411,15 +410,15 @@ class HD1K(FlowDataset):
image_2

Args:
root (string): Root directory of the HD1K Dataset.
root (str or ``pathlib.Path``): Root directory of the HD1K Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid_flow_mask`` and returns a transformed version.
"""

_has_builtin_flow_mask = True

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root=root, transforms=transforms)

verify_str_arg(split, "split", valid_values=("train", "test"))
Expand Down
42 changes: 21 additions & 21 deletions torchvision/datasets/_stereo_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class StereoMatchingDataset(ABC, VisionDataset):

_has_built_in_disparity_mask = False

def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
"""
Args:
root(str): Root directory of the dataset.
Expand Down Expand Up @@ -159,11 +159,11 @@ class CarlaStereo(StereoMatchingDataset):
...

Args:
root (string): Root directory where `carla-highres` is located.
root (str or ``pathlib.Path``): Root directory where `carla-highres` is located.
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""

def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

root = Path(root) / "carla-highres"
Expand Down Expand Up @@ -233,14 +233,14 @@ class Kitti2012Stereo(StereoMatchingDataset):
calib

Args:
root (string): Root directory where `Kitti2012` is located.
root (str or ``pathlib.Path``): Root directory where `Kitti2012` is located.
split (string, optional): The dataset split of scenes, either "train" (default) or "test".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""

_has_built_in_disparity_mask = True

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

verify_str_arg(split, "split", valid_values=("train", "test"))
Expand Down Expand Up @@ -321,14 +321,14 @@ class Kitti2015Stereo(StereoMatchingDataset):
calib

Args:
root (string): Root directory where `Kitti2015` is located.
root (str or ``pathlib.Path``): Root directory where `Kitti2015` is located.
split (string, optional): The dataset split of scenes, either "train" (default) or "test".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""

_has_built_in_disparity_mask = True

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

verify_str_arg(split, "split", valid_values=("train", "test"))
Expand Down Expand Up @@ -420,7 +420,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):
...

Args:
root (string): Root directory of the Middleburry 2014 Dataset.
root (str or ``pathlib.Path``): Root directory of the Middleburry 2014 Dataset.
split (string, optional): The dataset split of scenes, either "train" (default), "test", or "additional"
use_ambient_views (boolean, optional): Whether to use different expose or lightning views when possible.
The dataset samples with equal probability between ``[im1.png, im1E.png, im1L.png]``.
Expand Down Expand Up @@ -480,7 +480,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "train",
calibration: Optional[str] = "perfect",
use_ambient_views: bool = False,
Expand Down Expand Up @@ -576,7 +576,7 @@ def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.n
valid_mask = (disparity_map > 0).squeeze(0) # mask out invalid disparities
return disparity_map, valid_mask

def _download_dataset(self, root: str) -> None:
def _download_dataset(self, root: Union[str, Path]) -> None:
base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip"
# train and additional splits have 2 different calibration settings
root = Path(root) / "Middlebury2014"
Expand Down Expand Up @@ -675,7 +675,7 @@ class CREStereo(StereoMatchingDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
transforms: Optional[Callable] = None,
) -> None:
super().__init__(root, transforms)
Expand Down Expand Up @@ -757,12 +757,12 @@ class FallingThingsStereo(StereoMatchingDataset):
...

Args:
root (string): Root directory where FallingThings is located.
root (str or ``pathlib.Path``): Root directory where FallingThings is located.
variant (string): Which variant to use. Either "single", "mixed", or "both".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""

def __init__(self, root: str, variant: str = "single", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], variant: str = "single", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

root = Path(root) / "FallingThings"
Expand Down Expand Up @@ -868,7 +868,7 @@ class SceneFlowStereo(StereoMatchingDataset):
...

Args:
root (string): Root directory where SceneFlow is located.
root (str or ``pathlib.Path``): Root directory where SceneFlow is located.
variant (string): Which dataset variant to user, "FlyingThings3D" (default), "Monkaa" or "Driving".
pass_name (string): Which pass to use, "clean" (default), "final" or "both".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
Expand All @@ -877,7 +877,7 @@ class SceneFlowStereo(StereoMatchingDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
variant: str = "FlyingThings3D",
pass_name: str = "clean",
transforms: Optional[Callable] = None,
Expand Down Expand Up @@ -973,14 +973,14 @@ class SintelStereo(StereoMatchingDataset):
...

Args:
root (string): Root directory where Sintel Stereo is located.
root (str or ``pathlib.Path``): Root directory where Sintel Stereo is located.
pass_name (string): The name of the pass to use, either "final", "clean" or "both".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""

_has_built_in_disparity_mask = True

def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], pass_name: str = "final", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both"))
Expand Down Expand Up @@ -1082,12 +1082,12 @@ class InStereo2k(StereoMatchingDataset):
...

Args:
root (string): Root directory where InStereo2k is located.
root (str or ``pathlib.Path``): Root directory where InStereo2k is located.
split (string): Either "train" or "test".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

root = Path(root) / "InStereo2k" / split
Expand Down Expand Up @@ -1169,14 +1169,14 @@ class ETH3DStereo(StereoMatchingDataset):
...

Args:
root (string): Root directory of the ETH3D Dataset.
root (str or ``pathlib.Path``): Root directory of the ETH3D Dataset.
split (string, optional): The dataset split of scenes, either "train" (default) or "test".
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""

_has_built_in_disparity_mask = True

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

verify_str_arg(split, "split", valid_values=("train", "test"))
Expand Down
7 changes: 4 additions & 3 deletions torchvision/datasets/caltech.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import os.path
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union

from PIL import Image
Expand All @@ -16,7 +17,7 @@ class Caltech101(VisionDataset):
This class needs `scipy <https://docs.scipy.org/doc/>`_ to load target files from `.mat` format.

Args:
root (string): Root directory of dataset where directory
root (str or ``pathlib.Path``): Root directory of dataset where directory
``caltech101`` exists or will be saved to if download is set to True.
target_type (string or list, optional): Type of target to use, ``category`` or
``annotation``. Can also be a list to output a tuple with all specified
Expand All @@ -38,7 +39,7 @@ class Caltech101(VisionDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
target_type: Union[List[str], str] = "category",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
Expand Down Expand Up @@ -153,7 +154,7 @@ class Caltech256(VisionDataset):
"""`Caltech 256 <https://data.caltech.edu/records/20087>`_ Dataset.

Args:
root (string): Root directory of dataset where directory
root (str or ``pathlib.Path``): Root directory of dataset where directory
``caltech256`` exists or will be saved to if download is set to True.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
Expand Down
5 changes: 3 additions & 2 deletions torchvision/datasets/celeba.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import csv
import os
from collections import namedtuple
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union

import PIL
Expand All @@ -16,7 +17,7 @@ class CelebA(VisionDataset):
"""`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.

Args:
root (string): Root directory where images are downloaded to.
root (str or ``pathlib.Path``): Root directory where images are downloaded to.
split (string): One of {'train', 'valid', 'test', 'all'}.
Accordingly dataset is selected.
target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
Expand Down Expand Up @@ -63,7 +64,7 @@ class CelebA(VisionDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "train",
target_type: Union[List[str], str] = "attr",
transform: Optional[Callable] = None,
Expand Down
7 changes: 4 additions & 3 deletions torchvision/datasets/cifar.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os.path
import pickle
from typing import Any, Callable, Optional, Tuple
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union

import numpy as np
from PIL import Image
Expand All @@ -13,7 +14,7 @@ class CIFAR10(VisionDataset):
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

Args:
root (string): Root directory of dataset where directory
root (str or ``pathlib.Path``): Root directory of dataset where directory
``cifar-10-batches-py`` exists or will be saved to if download is set to True.
train (bool, optional): If True, creates dataset from training set, otherwise
creates from test set.
Expand Down Expand Up @@ -50,7 +51,7 @@ class CIFAR10(VisionDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
Expand Down
5 changes: 3 additions & 2 deletions torchvision/datasets/cityscapes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
from collections import namedtuple
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from PIL import Image
Expand All @@ -13,7 +14,7 @@ class Cityscapes(VisionDataset):
"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.

Args:
root (string): Root directory of dataset where directory ``leftImg8bit``
root (str or ``pathlib.Path``): Root directory of dataset where directory ``leftImg8bit``
and ``gtFine`` or ``gtCoarse`` are located.
split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
otherwise ``train``, ``train_extra`` or ``val``
Expand Down Expand Up @@ -103,7 +104,7 @@ class Cityscapes(VisionDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "train",
mode: str = "fine",
target_type: Union[List[str], str] = "instance",
Expand Down
Loading
Loading