From 01977a187961243eaf66ead246fd92a5c6c7ccfc Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 18 Mar 2024 08:54:41 +0100 Subject: [PATCH] fix root type annotations for remaining datasets --- torchvision/datasets/_optical_flow.py | 13 ++++++------- torchvision/datasets/_stereo_matching.py | 24 ++++++++++++------------ torchvision/datasets/caltech.py | 3 ++- torchvision/datasets/celeba.py | 3 ++- torchvision/datasets/cifar.py | 5 +++-- torchvision/datasets/cityscapes.py | 3 ++- torchvision/datasets/clevr.py | 4 ++-- torchvision/datasets/coco.py | 5 +++-- torchvision/datasets/country211.py | 4 ++-- torchvision/datasets/dtd.py | 4 ++-- torchvision/datasets/eurosat.py | 5 +++-- torchvision/datasets/fer2013.py | 4 ++-- torchvision/datasets/fgvc_aircraft.py | 5 +++-- torchvision/datasets/flickr.py | 3 +-- torchvision/datasets/flowers102.py | 4 ++-- torchvision/datasets/folder.py | 1 - torchvision/datasets/food101.py | 4 ++-- torchvision/datasets/gtsrb.py | 4 ++-- torchvision/datasets/hmdb51.py | 5 +++-- torchvision/datasets/imagenet.py | 3 +-- torchvision/datasets/imagenette.py | 4 ++-- torchvision/datasets/inaturalist.py | 3 ++- torchvision/datasets/kinetics.py | 5 +++-- torchvision/datasets/kitti.py | 5 +++-- torchvision/datasets/lfw.py | 3 ++- torchvision/datasets/lsun.py | 3 ++- torchvision/datasets/mnist.py | 9 +++++---- torchvision/datasets/moving_mnist.py | 5 +++-- torchvision/datasets/omniglot.py | 5 +++-- torchvision/datasets/oxford_iiit_pet.py | 2 +- torchvision/datasets/pcam.py | 4 ++-- torchvision/datasets/phototour.py | 8 +++++++- torchvision/datasets/places365.py | 5 +++-- torchvision/datasets/rendered_sst2.py | 4 ++-- torchvision/datasets/sbd.py | 5 +++-- torchvision/datasets/sbu.py | 5 +++-- torchvision/datasets/semeion.py | 5 +++-- torchvision/datasets/stanford_cars.py | 4 ++-- torchvision/datasets/stl10.py | 5 +++-- torchvision/datasets/sun397.py | 4 ++-- torchvision/datasets/svhn.py | 5 +++-- torchvision/datasets/ucf101.py | 5 +++-- torchvision/datasets/usps.py | 5 +++-- torchvision/datasets/voc.py | 8 ++++---- torchvision/datasets/widerface.py | 4 +++- 45 files changed, 127 insertions(+), 99 deletions(-) diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py index 5304f8b881a..40d25583942 100644 --- a/torchvision/datasets/_optical_flow.py +++ b/torchvision/datasets/_optical_flow.py @@ -13,7 +13,6 @@ from .utils import _read_pfm, verify_str_arg from .vision import VisionDataset - T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], Optional[np.ndarray]] T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]] @@ -33,7 +32,7 @@ class FlowDataset(ABC, VisionDataset): # and it's up to whatever consumes the dataset to decide what valid_flow_mask should be. _has_builtin_flow_mask = False - def __init__(self, root: str, transforms: Optional[Callable] = None) -> None: + def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None: super().__init__(root=root) self.transforms = transforms @@ -125,7 +124,7 @@ class Sintel(FlowDataset): def __init__( self, - root: str, + root: Union[str, Path], split: str = "train", pass_name: str = "clean", transforms: Optional[Callable] = None, @@ -191,7 +190,7 @@ class KittiFlow(FlowDataset): _has_builtin_flow_mask = True - def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None: + def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None: super().__init__(root=root, transforms=transforms) verify_str_arg(split, "split", valid_values=("train", "test")) @@ -256,7 +255,7 @@ class FlyingChairs(FlowDataset): return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`. """ - def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None: + def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None: super().__init__(root=root, transforms=transforms) verify_str_arg(split, "split", valid_values=("train", "val")) @@ -329,7 +328,7 @@ class FlyingThings3D(FlowDataset): def __init__( self, - root: str, + root: Union[str, Path], split: str = "train", pass_name: str = "clean", camera: str = "left", @@ -419,7 +418,7 @@ class HD1K(FlowDataset): _has_builtin_flow_mask = True - def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None: + def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None: super().__init__(root=root, transforms=transforms) verify_str_arg(split, "split", valid_values=("train", "test")) diff --git a/torchvision/datasets/_stereo_matching.py b/torchvision/datasets/_stereo_matching.py index 30444a07c1a..6a3f563a2da 100644 --- a/torchvision/datasets/_stereo_matching.py +++ b/torchvision/datasets/_stereo_matching.py @@ -27,7 +27,7 @@ class StereoMatchingDataset(ABC, VisionDataset): _has_built_in_disparity_mask = False - def __init__(self, root: str, transforms: Optional[Callable] = None) -> None: + def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None: """ Args: root(str): Root directory of the dataset. @@ -163,7 +163,7 @@ class CarlaStereo(StereoMatchingDataset): transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. """ - def __init__(self, root: str, transforms: Optional[Callable] = None) -> None: + def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None: super().__init__(root, transforms) root = Path(root) / "carla-highres" @@ -240,7 +240,7 @@ class Kitti2012Stereo(StereoMatchingDataset): _has_built_in_disparity_mask = True - def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None: + def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None: super().__init__(root, transforms) verify_str_arg(split, "split", valid_values=("train", "test")) @@ -328,7 +328,7 @@ class Kitti2015Stereo(StereoMatchingDataset): _has_built_in_disparity_mask = True - def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None: + def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None: super().__init__(root, transforms) verify_str_arg(split, "split", valid_values=("train", "test")) @@ -480,7 +480,7 @@ class Middlebury2014Stereo(StereoMatchingDataset): def __init__( self, - root: str, + root: Union[str, Path], split: str = "train", calibration: Optional[str] = "perfect", use_ambient_views: bool = False, @@ -576,7 +576,7 @@ def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.n valid_mask = (disparity_map > 0).squeeze(0) # mask out invalid disparities return disparity_map, valid_mask - def _download_dataset(self, root: str) -> None: + def _download_dataset(self, root: Union[str, Path]) -> None: base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip" # train and additional splits have 2 different calibration settings root = Path(root) / "Middlebury2014" @@ -675,7 +675,7 @@ class CREStereo(StereoMatchingDataset): def __init__( self, - root: str, + root: Union[str, Path], transforms: Optional[Callable] = None, ) -> None: super().__init__(root, transforms) @@ -762,7 +762,7 @@ class FallingThingsStereo(StereoMatchingDataset): transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. """ - def __init__(self, root: str, variant: str = "single", transforms: Optional[Callable] = None) -> None: + def __init__(self, root: Union[str, Path], variant: str = "single", transforms: Optional[Callable] = None) -> None: super().__init__(root, transforms) root = Path(root) / "FallingThings" @@ -877,7 +877,7 @@ class SceneFlowStereo(StereoMatchingDataset): def __init__( self, - root: str, + root: Union[str, Path], variant: str = "FlyingThings3D", pass_name: str = "clean", transforms: Optional[Callable] = None, @@ -980,7 +980,7 @@ class SintelStereo(StereoMatchingDataset): _has_built_in_disparity_mask = True - def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None) -> None: + def __init__(self, root: Union[str, Path], pass_name: str = "final", transforms: Optional[Callable] = None) -> None: super().__init__(root, transforms) verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both")) @@ -1087,7 +1087,7 @@ class InStereo2k(StereoMatchingDataset): transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version. """ - def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None: + def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None: super().__init__(root, transforms) root = Path(root) / "InStereo2k" / split @@ -1176,7 +1176,7 @@ class ETH3DStereo(StereoMatchingDataset): _has_built_in_disparity_mask = True - def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None: + def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None: super().__init__(root, transforms) verify_str_arg(split, "split", valid_values=("train", "test")) diff --git a/torchvision/datasets/caltech.py b/torchvision/datasets/caltech.py index 21ff1d9d38f..fe4f0fad208 100644 --- a/torchvision/datasets/caltech.py +++ b/torchvision/datasets/caltech.py @@ -1,5 +1,6 @@ import os import os.path +from pathlib import Path from typing import Any, Callable, List, Optional, Tuple, Union from PIL import Image @@ -38,7 +39,7 @@ class Caltech101(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], target_type: Union[List[str], str] = "category", transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, diff --git a/torchvision/datasets/celeba.py b/torchvision/datasets/celeba.py index 7b39be2ddca..147597d3ab3 100644 --- a/torchvision/datasets/celeba.py +++ b/torchvision/datasets/celeba.py @@ -1,6 +1,7 @@ import csv import os from collections import namedtuple +from pathlib import Path from typing import Any, Callable, List, Optional, Tuple, Union import PIL @@ -63,7 +64,7 @@ class CelebA(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], split: str = "train", target_type: Union[List[str], str] = "attr", transform: Optional[Callable] = None, diff --git a/torchvision/datasets/cifar.py b/torchvision/datasets/cifar.py index 741b94b5b4c..1637670ab91 100644 --- a/torchvision/datasets/cifar.py +++ b/torchvision/datasets/cifar.py @@ -1,6 +1,7 @@ import os.path import pickle -from typing import Any, Callable, Optional, Tuple +from pathlib import Path +from typing import Any, Callable, Optional, Tuple, Union import numpy as np from PIL import Image @@ -50,7 +51,7 @@ class CIFAR10(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, diff --git a/torchvision/datasets/cityscapes.py b/torchvision/datasets/cityscapes.py index 51d0748e8e2..6f7281f2574 100644 --- a/torchvision/datasets/cityscapes.py +++ b/torchvision/datasets/cityscapes.py @@ -1,6 +1,7 @@ import json import os from collections import namedtuple +from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union from PIL import Image @@ -103,7 +104,7 @@ class Cityscapes(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], split: str = "train", mode: str = "fine", target_type: Union[List[str], str] = "instance", diff --git a/torchvision/datasets/clevr.py b/torchvision/datasets/clevr.py index 5ca9254e45f..328eb7d79da 100644 --- a/torchvision/datasets/clevr.py +++ b/torchvision/datasets/clevr.py @@ -1,6 +1,6 @@ import json import pathlib -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple, Union from urllib.parse import urlparse from PIL import Image @@ -30,7 +30,7 @@ class CLEVRClassification(VisionDataset): def __init__( self, - root: str, + root: Union[str, pathlib.Path], split: str = "train", transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, diff --git a/torchvision/datasets/coco.py b/torchvision/datasets/coco.py index 7359c1887c4..f3b7be798b2 100644 --- a/torchvision/datasets/coco.py +++ b/torchvision/datasets/coco.py @@ -1,5 +1,6 @@ import os.path -from typing import Any, Callable, List, Optional, Tuple +from pathlib import Path +from typing import Any, Callable, List, Optional, Tuple, Union from PIL import Image @@ -24,7 +25,7 @@ class CocoDetection(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], annFile: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, diff --git a/torchvision/datasets/country211.py b/torchvision/datasets/country211.py index 8c53f8ae25f..a0f82ee1226 100644 --- a/torchvision/datasets/country211.py +++ b/torchvision/datasets/country211.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Callable, Optional +from typing import Callable, Optional, Union from .folder import ImageFolder from .utils import download_and_extract_archive, verify_str_arg @@ -28,7 +28,7 @@ class Country211(ImageFolder): def __init__( self, - root: str, + root: Union[str, Path], split: str = "train", transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, diff --git a/torchvision/datasets/dtd.py b/torchvision/datasets/dtd.py index 4abdf97cdcf..71c556bd201 100644 --- a/torchvision/datasets/dtd.py +++ b/torchvision/datasets/dtd.py @@ -1,6 +1,6 @@ import os import pathlib -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Optional, Tuple, Union import PIL.Image @@ -34,7 +34,7 @@ class DTD(VisionDataset): def __init__( self, - root: str, + root: Union[str, pathlib.Path], split: str = "train", partition: int = 1, transform: Optional[Callable] = None, diff --git a/torchvision/datasets/eurosat.py b/torchvision/datasets/eurosat.py index fa9881616d8..3f490b11902 100644 --- a/torchvision/datasets/eurosat.py +++ b/torchvision/datasets/eurosat.py @@ -1,5 +1,6 @@ import os -from typing import Callable, Optional +from pathlib import Path +from typing import Callable, Optional, Union from .folder import ImageFolder from .utils import download_and_extract_archive @@ -21,7 +22,7 @@ class EuroSAT(ImageFolder): def __init__( self, - root: str, + root: Union[str, Path], transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, diff --git a/torchvision/datasets/fer2013.py b/torchvision/datasets/fer2013.py index e073b4d4cd0..057fe695a13 100644 --- a/torchvision/datasets/fer2013.py +++ b/torchvision/datasets/fer2013.py @@ -1,6 +1,6 @@ import csv import pathlib -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Optional, Tuple, Union import torch from PIL import Image @@ -29,7 +29,7 @@ class FER2013(VisionDataset): def __init__( self, - root: str, + root: Union[str, pathlib.Path], split: str = "train", transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, diff --git a/torchvision/datasets/fgvc_aircraft.py b/torchvision/datasets/fgvc_aircraft.py index 89538b38b6f..bbf4e970a78 100644 --- a/torchvision/datasets/fgvc_aircraft.py +++ b/torchvision/datasets/fgvc_aircraft.py @@ -1,7 +1,8 @@ from __future__ import annotations import os -from typing import Any, Callable, Optional, Tuple +from pathlib import Path +from typing import Any, Callable, Optional, Tuple, Union import PIL.Image @@ -41,7 +42,7 @@ class FGVCAircraft(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], split: str = "trainval", annotation_level: str = "variant", transform: Optional[Callable] = None, diff --git a/torchvision/datasets/flickr.py b/torchvision/datasets/flickr.py index d542d3c5ced..1021309db05 100644 --- a/torchvision/datasets/flickr.py +++ b/torchvision/datasets/flickr.py @@ -2,7 +2,6 @@ import os from collections import defaultdict from html.parser import HTMLParser - from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -68,7 +67,7 @@ class Flickr8k(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], ann_file: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, diff --git a/torchvision/datasets/flowers102.py b/torchvision/datasets/flowers102.py index 1e5d74e9561..07f403702f5 100644 --- a/torchvision/datasets/flowers102.py +++ b/torchvision/datasets/flowers102.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Optional, Tuple, Union import PIL.Image @@ -42,7 +42,7 @@ class Flowers102(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], split: str = "train", transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index 756f160e28f..9ee06b6a650 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -1,6 +1,5 @@ import os import os.path - from pathlib import Path from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union diff --git a/torchvision/datasets/food101.py b/torchvision/datasets/food101.py index ac0ddc0da8a..f734787c1bf 100644 --- a/torchvision/datasets/food101.py +++ b/torchvision/datasets/food101.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Optional, Tuple, Union import PIL.Image @@ -34,7 +34,7 @@ class Food101(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], split: str = "train", transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, diff --git a/torchvision/datasets/gtsrb.py b/torchvision/datasets/gtsrb.py index 340a0b53754..a3d012c70b2 100644 --- a/torchvision/datasets/gtsrb.py +++ b/torchvision/datasets/gtsrb.py @@ -1,6 +1,6 @@ import csv import pathlib -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Optional, Tuple, Union import PIL @@ -25,7 +25,7 @@ class GTSRB(VisionDataset): def __init__( self, - root: str, + root: Union[str, pathlib.Path], split: str = "train", transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, diff --git a/torchvision/datasets/hmdb51.py b/torchvision/datasets/hmdb51.py index e0b322ca6c4..8377e40d57c 100644 --- a/torchvision/datasets/hmdb51.py +++ b/torchvision/datasets/hmdb51.py @@ -1,6 +1,7 @@ import glob import os -from typing import Any, Callable, Dict, List, Optional, Tuple +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from torch import Tensor @@ -59,7 +60,7 @@ class HMDB51(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], annotation_path: str, frames_per_clip: int, step_between_clips: int = 1, diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py index 5548840ee2b..d7caf328d2b 100644 --- a/torchvision/datasets/imagenet.py +++ b/torchvision/datasets/imagenet.py @@ -2,7 +2,6 @@ import shutil import tempfile from contextlib import contextmanager - from pathlib import Path from typing import Any, Dict, Iterator, List, Optional, Tuple, Union @@ -47,7 +46,7 @@ class ImageNet(ImageFolder): targets (list): The class_index value for each image in the dataset """ - def __init__(self, root: str, split: str = "train", **kwargs: Any) -> None: + def __init__(self, root: Union[str, Path], split: str = "train", **kwargs: Any) -> None: root = self.root = os.path.expanduser(root) self.split = verify_str_arg(split, "split", ("train", "val")) diff --git a/torchvision/datasets/imagenette.py b/torchvision/datasets/imagenette.py index dec92515b48..05da537891b 100644 --- a/torchvision/datasets/imagenette.py +++ b/torchvision/datasets/imagenette.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Optional, Tuple, Union from PIL import Image @@ -48,7 +48,7 @@ class Imagenette(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], split: str = "train", size: str = "full", download=False, diff --git a/torchvision/datasets/inaturalist.py b/torchvision/datasets/inaturalist.py index d06b4376e01..68f9a77f56a 100644 --- a/torchvision/datasets/inaturalist.py +++ b/torchvision/datasets/inaturalist.py @@ -1,5 +1,6 @@ import os import os.path +from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union from PIL import Image @@ -65,7 +66,7 @@ class INaturalist(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], version: str = "2021_train", target_type: Union[List[str], str] = "full", transform: Optional[Callable] = None, diff --git a/torchvision/datasets/kinetics.py b/torchvision/datasets/kinetics.py index d26017c89d3..42d32533953 100644 --- a/torchvision/datasets/kinetics.py +++ b/torchvision/datasets/kinetics.py @@ -5,7 +5,8 @@ from functools import partial from multiprocessing import Pool from os import path -from typing import Any, Callable, Dict, Optional, Tuple +from pathlib import Path +from typing import Any, Callable, Dict, Optional, Tuple, Union from torch import Tensor @@ -90,7 +91,7 @@ class Kinetics(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], frames_per_clip: int, num_classes: str = "400", split: str = "train", diff --git a/torchvision/datasets/kitti.py b/torchvision/datasets/kitti.py index 42a3054383c..69e603c76f2 100644 --- a/torchvision/datasets/kitti.py +++ b/torchvision/datasets/kitti.py @@ -1,6 +1,7 @@ import csv import os -from typing import Any, Callable, List, Optional, Tuple +from pathlib import Path +from typing import Any, Callable, List, Optional, Tuple, Union from PIL import Image @@ -51,7 +52,7 @@ class Kitti(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, diff --git a/torchvision/datasets/lfw.py b/torchvision/datasets/lfw.py index e1c971c36ea..69f1edaf72f 100644 --- a/torchvision/datasets/lfw.py +++ b/torchvision/datasets/lfw.py @@ -1,4 +1,5 @@ import os +from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union from PIL import Image @@ -31,7 +32,7 @@ class _LFW(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], split: str, image_set: str, view: str, diff --git a/torchvision/datasets/lsun.py b/torchvision/datasets/lsun.py index 2eb67a35bda..a2f5e18b991 100644 --- a/torchvision/datasets/lsun.py +++ b/torchvision/datasets/lsun.py @@ -3,6 +3,7 @@ import pickle import string from collections.abc import Iterable +from pathlib import Path from typing import Any, Callable, cast, List, Optional, Tuple, Union from PIL import Image @@ -71,7 +72,7 @@ class LSUN(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], classes: Union[str, List[str]] = "train", transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index 4ca00f61896..a2389d598e6 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -5,7 +5,8 @@ import string import sys import warnings -from typing import Any, Callable, Dict, List, Optional, Tuple +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from urllib.error import URLError import numpy as np @@ -82,7 +83,7 @@ def test_data(self): def __init__( self, - root: str, + root: Union[str, Path], train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, @@ -290,7 +291,7 @@ class EMNIST(MNIST): "mnist": list(string.digits), } - def __init__(self, root: str, split: str, **kwargs: Any) -> None: + def __init__(self, root: Union[str, Path], split: str, **kwargs: Any) -> None: self.split = verify_str_arg(split, "split", self.splits) self.training_file = self._training_file(split) self.test_file = self._test_file(split) @@ -416,7 +417,7 @@ class QMNIST(MNIST): ] def __init__( - self, root: str, what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any + self, root: Union[str, Path], what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any ) -> None: if what is None: what = "train" if train else "test" diff --git a/torchvision/datasets/moving_mnist.py b/torchvision/datasets/moving_mnist.py index f9c848cb0a2..d02811762b8 100644 --- a/torchvision/datasets/moving_mnist.py +++ b/torchvision/datasets/moving_mnist.py @@ -1,5 +1,6 @@ import os.path -from typing import Callable, Optional +from pathlib import Path +from typing import Callable, Optional, Union import numpy as np import torch @@ -28,7 +29,7 @@ class MovingMNIST(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], split: Optional[str] = None, split_ratio: int = 10, download: bool = False, diff --git a/torchvision/datasets/omniglot.py b/torchvision/datasets/omniglot.py index 25d76917011..c02cf91234a 100644 --- a/torchvision/datasets/omniglot.py +++ b/torchvision/datasets/omniglot.py @@ -1,5 +1,6 @@ from os.path import join -from typing import Any, Callable, List, Optional, Tuple +from pathlib import Path +from typing import Any, Callable, List, Optional, Tuple, Union from PIL import Image @@ -33,7 +34,7 @@ class Omniglot(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], background: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, diff --git a/torchvision/datasets/oxford_iiit_pet.py b/torchvision/datasets/oxford_iiit_pet.py index 3ba33209cda..9fe78901626 100644 --- a/torchvision/datasets/oxford_iiit_pet.py +++ b/torchvision/datasets/oxford_iiit_pet.py @@ -38,7 +38,7 @@ class OxfordIIITPet(VisionDataset): def __init__( self, - root: str, + root: Union[str, pathlib.Path], split: str = "trainval", target_types: Union[Sequence[str], str] = "category", transforms: Optional[Callable] = None, diff --git a/torchvision/datasets/pcam.py b/torchvision/datasets/pcam.py index 57f60738aad..8849e0ea39d 100644 --- a/torchvision/datasets/pcam.py +++ b/torchvision/datasets/pcam.py @@ -1,5 +1,5 @@ import pathlib -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Optional, Tuple, Union from PIL import Image @@ -72,7 +72,7 @@ class PCAM(VisionDataset): def __init__( self, - root: str, + root: Union[str, pathlib.Path], split: str = "train", transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, diff --git a/torchvision/datasets/phototour.py b/torchvision/datasets/phototour.py index 583299e2010..fd2466a3d36 100644 --- a/torchvision/datasets/phototour.py +++ b/torchvision/datasets/phototour.py @@ -1,4 +1,5 @@ import os +from pathlib import Path from typing import Any, Callable, List, Optional, Tuple, Union import numpy as np @@ -87,7 +88,12 @@ class PhotoTour(VisionDataset): matches_files = "m50_100000_100000_0.txt" def __init__( - self, root: str, name: str, train: bool = True, transform: Optional[Callable] = None, download: bool = False + self, + root: Union[str, Path], + name: str, + train: bool = True, + transform: Optional[Callable] = None, + download: bool = False, ) -> None: super().__init__(root, transform=transform) self.name = name diff --git a/torchvision/datasets/places365.py b/torchvision/datasets/places365.py index 29b014b509b..98966e1dc2f 100644 --- a/torchvision/datasets/places365.py +++ b/torchvision/datasets/places365.py @@ -1,6 +1,7 @@ import os from os import path -from typing import Any, Callable, Dict, List, Optional, Tuple +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from urllib.parse import urljoin from .folder import default_loader @@ -62,7 +63,7 @@ class Places365(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], split: str = "train-standard", small: bool = False, download: bool = False, diff --git a/torchvision/datasets/rendered_sst2.py b/torchvision/datasets/rendered_sst2.py index f536ce43681..48b0ddfc4fb 100644 --- a/torchvision/datasets/rendered_sst2.py +++ b/torchvision/datasets/rendered_sst2.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Optional, Tuple, Union import PIL.Image @@ -35,7 +35,7 @@ class RenderedSST2(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], split: str = "train", transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, diff --git a/torchvision/datasets/sbd.py b/torchvision/datasets/sbd.py index 12b018beed3..7f245675b2d 100644 --- a/torchvision/datasets/sbd.py +++ b/torchvision/datasets/sbd.py @@ -1,6 +1,7 @@ import os import shutil -from typing import Any, Callable, Optional, Tuple +from pathlib import Path +from typing import Any, Callable, Optional, Tuple, Union import numpy as np from PIL import Image @@ -51,7 +52,7 @@ class SBDataset(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], image_set: str = "train", mode: str = "boundaries", download: bool = False, diff --git a/torchvision/datasets/sbu.py b/torchvision/datasets/sbu.py index 1b3088bc061..3c349370a12 100644 --- a/torchvision/datasets/sbu.py +++ b/torchvision/datasets/sbu.py @@ -1,5 +1,6 @@ import os -from typing import Any, Callable, Optional, Tuple +from pathlib import Path +from typing import Any, Callable, Optional, Tuple, Union from PIL import Image @@ -28,7 +29,7 @@ class SBU(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = True, diff --git a/torchvision/datasets/semeion.py b/torchvision/datasets/semeion.py index 53e048a4755..d0344c74241 100644 --- a/torchvision/datasets/semeion.py +++ b/torchvision/datasets/semeion.py @@ -1,5 +1,6 @@ import os.path -from typing import Any, Callable, Optional, Tuple +from pathlib import Path +from typing import Any, Callable, Optional, Tuple, Union import numpy as np from PIL import Image @@ -29,7 +30,7 @@ class SEMEION(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = True, diff --git a/torchvision/datasets/stanford_cars.py b/torchvision/datasets/stanford_cars.py index 0e0537f454f..20f95d9c8ce 100644 --- a/torchvision/datasets/stanford_cars.py +++ b/torchvision/datasets/stanford_cars.py @@ -1,5 +1,5 @@ import pathlib -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Optional, Tuple, Union from PIL import Image @@ -35,7 +35,7 @@ class StanfordCars(VisionDataset): def __init__( self, - root: str, + root: Union[str, pathlib.Path], split: str = "train", transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, diff --git a/torchvision/datasets/stl10.py b/torchvision/datasets/stl10.py index 2cb24a8d077..90ff41738eb 100644 --- a/torchvision/datasets/stl10.py +++ b/torchvision/datasets/stl10.py @@ -1,5 +1,6 @@ import os.path -from typing import Any, Callable, cast, Optional, Tuple +from pathlib import Path +from typing import Any, Callable, cast, Optional, Tuple, Union import numpy as np from PIL import Image @@ -45,7 +46,7 @@ class STL10(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], split: str = "train", folds: Optional[int] = None, transform: Optional[Callable] = None, diff --git a/torchvision/datasets/sun397.py b/torchvision/datasets/sun397.py index 619ae97b2a7..4db0a3cf237 100644 --- a/torchvision/datasets/sun397.py +++ b/torchvision/datasets/sun397.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Optional, Tuple, Union import PIL.Image @@ -28,7 +28,7 @@ class SUN397(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, diff --git a/torchvision/datasets/svhn.py b/torchvision/datasets/svhn.py index 292d20d5eb2..5d20d7db7e3 100644 --- a/torchvision/datasets/svhn.py +++ b/torchvision/datasets/svhn.py @@ -1,5 +1,6 @@ import os.path -from typing import Any, Callable, Optional, Tuple +from pathlib import Path +from typing import Any, Callable, Optional, Tuple, Union import numpy as np from PIL import Image @@ -52,7 +53,7 @@ class SVHN(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], split: str = "train", transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, diff --git a/torchvision/datasets/ucf101.py b/torchvision/datasets/ucf101.py index 3f2fddedfcd..935f8ad41c7 100644 --- a/torchvision/datasets/ucf101.py +++ b/torchvision/datasets/ucf101.py @@ -1,5 +1,6 @@ import os -from typing import Any, Callable, Dict, List, Optional, Tuple +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from torch import Tensor @@ -52,7 +53,7 @@ class UCF101(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], annotation_path: str, frames_per_clip: int, step_between_clips: int = 1, diff --git a/torchvision/datasets/usps.py b/torchvision/datasets/usps.py index 51ff8022e17..9c681e79f6c 100644 --- a/torchvision/datasets/usps.py +++ b/torchvision/datasets/usps.py @@ -1,5 +1,6 @@ import os -from typing import Any, Callable, Optional, Tuple +from pathlib import Path +from typing import Any, Callable, Optional, Tuple, Union import numpy as np from PIL import Image @@ -43,7 +44,7 @@ class USPS(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, diff --git a/torchvision/datasets/voc.py b/torchvision/datasets/voc.py index feb2b427275..0f0e84c84fa 100644 --- a/torchvision/datasets/voc.py +++ b/torchvision/datasets/voc.py @@ -1,18 +1,18 @@ import collections import os +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from xml.etree.ElementTree import Element as ET_Element -from .vision import VisionDataset - try: from defusedxml.ElementTree import parse as ET_parse except ImportError: from xml.etree.ElementTree import parse as ET_parse -from typing import Any, Callable, Dict, List, Optional, Tuple from PIL import Image from .utils import download_and_extract_archive, verify_str_arg +from .vision import VisionDataset DATASET_YEAR_DICT = { "2012": { @@ -67,7 +67,7 @@ class _VOCBase(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], year: str = "2012", image_set: str = "train", download: bool = False, diff --git a/torchvision/datasets/widerface.py b/torchvision/datasets/widerface.py index e2d4ef87037..90f80b7175b 100644 --- a/torchvision/datasets/widerface.py +++ b/torchvision/datasets/widerface.py @@ -1,5 +1,7 @@ import os from os.path import abspath, expanduser +from pathlib import Path + from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -55,7 +57,7 @@ class WIDERFace(VisionDataset): def __init__( self, - root: str, + root: Union[str, Path], split: str = "train", transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,