Skip to content

Commit

Permalink
fix root type annotations for remaining datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Mar 18, 2024
1 parent a728ed7 commit 01977a1
Show file tree
Hide file tree
Showing 45 changed files with 127 additions and 99 deletions.
13 changes: 6 additions & 7 deletions torchvision/datasets/_optical_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from .utils import _read_pfm, verify_str_arg
from .vision import VisionDataset


T1 = Tuple[Image.Image, Image.Image, Optional[np.ndarray], Optional[np.ndarray]]
T2 = Tuple[Image.Image, Image.Image, Optional[np.ndarray]]

Expand All @@ -33,7 +32,7 @@ class FlowDataset(ABC, VisionDataset):
# and it's up to whatever consumes the dataset to decide what valid_flow_mask should be.
_has_builtin_flow_mask = False

def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:

super().__init__(root=root)
self.transforms = transforms
Expand Down Expand Up @@ -125,7 +124,7 @@ class Sintel(FlowDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "train",
pass_name: str = "clean",
transforms: Optional[Callable] = None,
Expand Down Expand Up @@ -191,7 +190,7 @@ class KittiFlow(FlowDataset):

_has_builtin_flow_mask = True

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root=root, transforms=transforms)

verify_str_arg(split, "split", valid_values=("train", "test"))
Expand Down Expand Up @@ -256,7 +255,7 @@ class FlyingChairs(FlowDataset):
return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
"""

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root=root, transforms=transforms)

verify_str_arg(split, "split", valid_values=("train", "val"))
Expand Down Expand Up @@ -329,7 +328,7 @@ class FlyingThings3D(FlowDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "train",
pass_name: str = "clean",
camera: str = "left",
Expand Down Expand Up @@ -419,7 +418,7 @@ class HD1K(FlowDataset):

_has_builtin_flow_mask = True

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root=root, transforms=transforms)

verify_str_arg(split, "split", valid_values=("train", "test"))
Expand Down
24 changes: 12 additions & 12 deletions torchvision/datasets/_stereo_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class StereoMatchingDataset(ABC, VisionDataset):

_has_built_in_disparity_mask = False

def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
"""
Args:
root(str): Root directory of the dataset.
Expand Down Expand Up @@ -163,7 +163,7 @@ class CarlaStereo(StereoMatchingDataset):
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""

def __init__(self, root: str, transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

root = Path(root) / "carla-highres"
Expand Down Expand Up @@ -240,7 +240,7 @@ class Kitti2012Stereo(StereoMatchingDataset):

_has_built_in_disparity_mask = True

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

verify_str_arg(split, "split", valid_values=("train", "test"))
Expand Down Expand Up @@ -328,7 +328,7 @@ class Kitti2015Stereo(StereoMatchingDataset):

_has_built_in_disparity_mask = True

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

verify_str_arg(split, "split", valid_values=("train", "test"))
Expand Down Expand Up @@ -480,7 +480,7 @@ class Middlebury2014Stereo(StereoMatchingDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "train",
calibration: Optional[str] = "perfect",
use_ambient_views: bool = False,
Expand Down Expand Up @@ -576,7 +576,7 @@ def _read_disparity(self, file_path: str) -> Union[Tuple[None, None], Tuple[np.n
valid_mask = (disparity_map > 0).squeeze(0) # mask out invalid disparities
return disparity_map, valid_mask

def _download_dataset(self, root: str) -> None:
def _download_dataset(self, root: Union[str, Path]) -> None:
base_url = "https://vision.middlebury.edu/stereo/data/scenes2014/zip"
# train and additional splits have 2 different calibration settings
root = Path(root) / "Middlebury2014"
Expand Down Expand Up @@ -675,7 +675,7 @@ class CREStereo(StereoMatchingDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
transforms: Optional[Callable] = None,
) -> None:
super().__init__(root, transforms)
Expand Down Expand Up @@ -762,7 +762,7 @@ class FallingThingsStereo(StereoMatchingDataset):
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""

def __init__(self, root: str, variant: str = "single", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], variant: str = "single", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

root = Path(root) / "FallingThings"
Expand Down Expand Up @@ -877,7 +877,7 @@ class SceneFlowStereo(StereoMatchingDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
variant: str = "FlyingThings3D",
pass_name: str = "clean",
transforms: Optional[Callable] = None,
Expand Down Expand Up @@ -980,7 +980,7 @@ class SintelStereo(StereoMatchingDataset):

_has_built_in_disparity_mask = True

def __init__(self, root: str, pass_name: str = "final", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], pass_name: str = "final", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

verify_str_arg(pass_name, "pass_name", valid_values=("final", "clean", "both"))
Expand Down Expand Up @@ -1087,7 +1087,7 @@ class InStereo2k(StereoMatchingDataset):
transforms (callable, optional): A function/transform that takes in a sample and returns a transformed version.
"""

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

root = Path(root) / "InStereo2k" / split
Expand Down Expand Up @@ -1176,7 +1176,7 @@ class ETH3DStereo(StereoMatchingDataset):

_has_built_in_disparity_mask = True

def __init__(self, root: str, split: str = "train", transforms: Optional[Callable] = None) -> None:
def __init__(self, root: Union[str, Path], split: str = "train", transforms: Optional[Callable] = None) -> None:
super().__init__(root, transforms)

verify_str_arg(split, "split", valid_values=("train", "test"))
Expand Down
3 changes: 2 additions & 1 deletion torchvision/datasets/caltech.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import os.path
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union

from PIL import Image
Expand Down Expand Up @@ -38,7 +39,7 @@ class Caltech101(VisionDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
target_type: Union[List[str], str] = "category",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
Expand Down
3 changes: 2 additions & 1 deletion torchvision/datasets/celeba.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import csv
import os
from collections import namedtuple
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union

import PIL
Expand Down Expand Up @@ -63,7 +64,7 @@ class CelebA(VisionDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "train",
target_type: Union[List[str], str] = "attr",
transform: Optional[Callable] = None,
Expand Down
5 changes: 3 additions & 2 deletions torchvision/datasets/cifar.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os.path
import pickle
from typing import Any, Callable, Optional, Tuple
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union

import numpy as np
from PIL import Image
Expand Down Expand Up @@ -50,7 +51,7 @@ class CIFAR10(VisionDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
Expand Down
3 changes: 2 additions & 1 deletion torchvision/datasets/cityscapes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
from collections import namedtuple
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from PIL import Image
Expand Down Expand Up @@ -103,7 +104,7 @@ class Cityscapes(VisionDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "train",
mode: str = "fine",
target_type: Union[List[str], str] = "instance",
Expand Down
4 changes: 2 additions & 2 deletions torchvision/datasets/clevr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import pathlib
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple, Union
from urllib.parse import urlparse

from PIL import Image
Expand Down Expand Up @@ -30,7 +30,7 @@ class CLEVRClassification(VisionDataset):

def __init__(
self,
root: str,
root: Union[str, pathlib.Path],
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
Expand Down
5 changes: 3 additions & 2 deletions torchvision/datasets/coco.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os.path
from typing import Any, Callable, List, Optional, Tuple
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union

from PIL import Image

Expand All @@ -24,7 +25,7 @@ class CocoDetection(VisionDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
annFile: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
Expand Down
4 changes: 2 additions & 2 deletions torchvision/datasets/country211.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Callable, Optional
from typing import Callable, Optional, Union

from .folder import ImageFolder
from .utils import download_and_extract_archive, verify_str_arg
Expand Down Expand Up @@ -28,7 +28,7 @@ class Country211(ImageFolder):

def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
Expand Down
4 changes: 2 additions & 2 deletions torchvision/datasets/dtd.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import pathlib
from typing import Any, Callable, Optional, Tuple
from typing import Any, Callable, Optional, Tuple, Union

import PIL.Image

Expand Down Expand Up @@ -34,7 +34,7 @@ class DTD(VisionDataset):

def __init__(
self,
root: str,
root: Union[str, pathlib.Path],
split: str = "train",
partition: int = 1,
transform: Optional[Callable] = None,
Expand Down
5 changes: 3 additions & 2 deletions torchvision/datasets/eurosat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from typing import Callable, Optional
from pathlib import Path
from typing import Callable, Optional, Union

from .folder import ImageFolder
from .utils import download_and_extract_archive
Expand All @@ -21,7 +22,7 @@ class EuroSAT(ImageFolder):

def __init__(
self,
root: str,
root: Union[str, Path],
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
Expand Down
4 changes: 2 additions & 2 deletions torchvision/datasets/fer2013.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import csv
import pathlib
from typing import Any, Callable, Optional, Tuple
from typing import Any, Callable, Optional, Tuple, Union

import torch
from PIL import Image
Expand Down Expand Up @@ -29,7 +29,7 @@ class FER2013(VisionDataset):

def __init__(
self,
root: str,
root: Union[str, pathlib.Path],
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
Expand Down
5 changes: 3 additions & 2 deletions torchvision/datasets/fgvc_aircraft.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import os
from typing import Any, Callable, Optional, Tuple
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union

import PIL.Image

Expand Down Expand Up @@ -41,7 +42,7 @@ class FGVCAircraft(VisionDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "trainval",
annotation_level: str = "variant",
transform: Optional[Callable] = None,
Expand Down
3 changes: 1 addition & 2 deletions torchvision/datasets/flickr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
from collections import defaultdict
from html.parser import HTMLParser

from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -68,7 +67,7 @@ class Flickr8k(VisionDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
ann_file: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
Expand Down
4 changes: 2 additions & 2 deletions torchvision/datasets/flowers102.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Any, Callable, Optional, Tuple
from typing import Any, Callable, Optional, Tuple, Union

import PIL.Image

Expand Down Expand Up @@ -42,7 +42,7 @@ class Flowers102(VisionDataset):

def __init__(
self,
root: str,
root: Union[str, Path],
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
Expand Down
1 change: 0 additions & 1 deletion torchvision/datasets/folder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
import os.path

from pathlib import Path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union

Expand Down
Loading

0 comments on commit 01977a1

Please sign in to comment.