diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 87549c9719f..d5af9d67434 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -93,6 +93,12 @@ iNaturalist .. autoclass:: INaturalist +LandCover.ai Geo +^^^^^^^^^^^^^^^^ + +.. autoclass:: LandCoverAIBase +.. autoclass:: LandCoverAIGeo + Landsat ^^^^^^^ diff --git a/docs/api/geo_datasets.csv b/docs/api/geo_datasets.csv index 9bee384cd95..66b9245b865 100644 --- a/docs/api/geo_datasets.csv +++ b/docs/api/geo_datasets.csv @@ -12,6 +12,7 @@ Dataset,Type,Source,Size (px),Resolution (m) `GBIF`_,Points,Citizen Scientists,-,- `GlobBiomass`_,Masks,Landsat,"45,000x45,000",100 `iNaturalist`_,Points,Citizen Scientists,-,- +`LandCover.ai Geo`_,"Imagery, Masks",Aerial,"4,200--9,500",0.25--0.5 `Landsat`_,Imagery,Landsat,"8,900x8,900",30 `NAIP`_,Imagery,Aerial,"6,100x7,600",1 `Open Buildings`_,Geometries,"Maxar, CNES/Airbus",-,- diff --git a/tests/datasets/test_landcoverai.py b/tests/datasets/test_landcoverai.py index e8e64680ef0..c4f11e05b61 100644 --- a/tests/datasets/test_landcoverai.py +++ b/tests/datasets/test_landcoverai.py @@ -14,13 +14,62 @@ from torch.utils.data import ConcatDataset import torchgeo.datasets.utils -from torchgeo.datasets import LandCoverAI +from torchgeo.datasets import BoundingBox, LandCoverAI, LandCoverAIGeo def download_url(url: str, root: str, *args: str, **kwargs: str) -> None: shutil.copy(url, root) +class TestLandCoverAIGeo: + @pytest.fixture + def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> LandCoverAIGeo: + monkeypatch.setattr(torchgeo.datasets.landcoverai, "download_url", download_url) + md5 = "ff8998857cc8511f644d3f7d0f3688d0" + monkeypatch.setattr(LandCoverAIGeo, "md5", md5) + url = os.path.join("tests", "data", "landcoverai", "landcover.ai.v1.zip") + monkeypatch.setattr(LandCoverAIGeo, "url", url) + root = str(tmp_path) + transforms = nn.Identity() + return LandCoverAIGeo(root, transforms=transforms, download=True, checksum=True) + + def 
test_getitem(self, dataset: LandCoverAIGeo) -> None: + x = dataset[dataset.bounds] + assert isinstance(x, dict) + assert isinstance(x["image"], torch.Tensor) + assert isinstance(x["mask"], torch.Tensor) + + def test_already_extracted(self, dataset: LandCoverAIGeo) -> None: + LandCoverAIGeo(root=dataset.root, download=True) + + def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None: + url = os.path.join("tests", "data", "landcoverai", "landcover.ai.v1.zip") + root = str(tmp_path) + shutil.copy(url, root) + LandCoverAIGeo(root) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(RuntimeError, match="Dataset not found"): + LandCoverAIGeo(str(tmp_path)) + + def test_out_of_bounds_query(self, dataset: LandCoverAIGeo) -> None: + query = BoundingBox(0, 0, 0, 0, 0, 0) + with pytest.raises( + IndexError, match="query: .* not found in index with bounds:" + ): + dataset[query] + + def test_plot(self, dataset: LandCoverAIGeo) -> None: + x = dataset[dataset.bounds].copy() + dataset.plot(x, suptitle="Test") + plt.close() + dataset.plot(x, show_titles=False) + plt.close() + x["prediction"] = x["mask"][:, :, 0].clone().unsqueeze(2) + dataset.plot(x) + plt.close() + + class TestLandCoverAI: @pytest.fixture(params=["train", "val", "test"]) def dataset( diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 724fafad675..5bcf3e65195 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -53,7 +53,7 @@ from .idtrees import IDTReeS from .inaturalist import INaturalist from .inria import InriaAerialImageLabeling -from .landcoverai import LandCoverAI +from .landcoverai import LandCoverAI, LandCoverAIBase, LandCoverAIGeo from .landsat import ( Landsat, Landsat1, @@ -137,6 +137,8 @@ "GBIF", "GlobBiomass", "INaturalist", + "LandCoverAIBase", + "LandCoverAIGeo", "Landsat", "Landsat1", "Landsat2", diff --git a/torchgeo/datasets/landcoverai.py b/torchgeo/datasets/landcoverai.py index 
c9e7c80c41f..aad853bb2f2 100644 --- a/torchgeo/datasets/landcoverai.py +++ b/torchgeo/datasets/landcoverai.py @@ -2,31 +2,33 @@ # Licensed under the MIT License. """LandCover.ai dataset.""" - +import abc import glob import hashlib import os from functools import lru_cache -from typing import Callable, Dict, Optional +from typing import Any, Callable, Dict, List, Optional, cast import matplotlib.pyplot as plt import numpy as np import torch from matplotlib.colors import ListedColormap from PIL import Image +from rasterio.crs import CRS from torch import Tensor +from torch.utils.data import Dataset -from .geo import NonGeoDataset -from .utils import download_url, extract_archive, working_dir +from .geo import NonGeoDataset, RasterDataset +from .utils import BoundingBox, download_url, extract_archive, working_dir -class LandCoverAI(NonGeoDataset): - r"""LandCover.ai dataset. +class LandCoverAIBase(Dataset[Dict[str, Any]], abc.ABC): + r"""Abstract base class for LandCover.ai Geo and NonGeo datasets. - The `LandCover.ai `__ (Land Cover from Aerial Imagery) - dataset is a dataset for automatic mapping of buildings, woodlands, water and - roads from aerial images. This implementation is specifically for Version 1 of - Landcover.ai. + The `LandCover.ai <https://landcover.ai.linuxpolska.com/>`__ (Land Cover from + Aerial Imagery) dataset is a dataset for automatic mapping of buildings, woodlands, + water and roads from aerial images. This implementation is specifically for + Version 1 of LandCover.ai. Dataset features: @@ -50,7 +52,241 @@ class LandCoverAI(NonGeoDataset): If you use this dataset in your research, please cite the following paper: - * https://arxiv.org/abs/2005.02264v3 + * https://arxiv.org/abs/2005.02264v4 + + .. 
versionadded:: 0.5 + """ + + url = "https://landcover.ai.linuxpolska.com/download/landcover.ai.v1.zip" + filename = "landcover.ai.v1.zip" + md5 = "3268c89070e8734b4e91d531c0617e03" + classes = ["Background", "Building", "Woodland", "Water", "Road"] + cmap = { + 0: (0, 0, 0, 0), + 1: (97, 74, 74, 255), + 2: (38, 115, 0, 255), + 3: (0, 197, 255, 255), + 4: (207, 207, 207, 255), + } + + def __init__( + self, root: str = "data", download: bool = False, checksum: bool = False + ) -> None: + """Initialize a new LandCover.ai dataset instance. + + Args: + root: root directory where dataset can be found + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + cache: if True, cache file handle to speed up repeated sampling + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + RuntimeError: if ``download=False`` and data is not found, or checksums + don't match + """ + self.root = root + self.download = download + self.checksum = checksum + + lc_colors = np.zeros((max(self.cmap.keys()) + 1, 4)) + lc_colors[list(self.cmap.keys())] = list(self.cmap.values()) + lc_colors = lc_colors[:, :3] / 255 + self._lc_cmap = ListedColormap(lc_colors) + + self._verify() + + def _verify(self) -> None: + """Verify the integrity of the dataset. + + Raises: + RuntimeError: if ``download=False`` but dataset is missing or checksum fails + """ + if self._verify_data(): + return + + # Check if the zip file has already been downloaded + pathname = os.path.join(self.root, self.filename) + if os.path.exists(pathname): + self._extract() + return + + # Check if the user requested to download the dataset + if not self.download: + raise RuntimeError( + f"Dataset not found in `root={self.root}` and `download=False`, " + "either specify a different `root` directory or use `download=True` " + "to automatically download the dataset." 
+ ) + + # Download the dataset + self._download() + self._extract() + + @abc.abstractmethod + def __getitem__(self, query: Any) -> Dict[str, Any]: + """Retrieve image, mask and metadata indexed by query. + + Args: + query: coordinates or an index + + Returns: + sample of image, mask and metadata at that index + + Raises: + IndexError: if query is not found in the index + """ + + @abc.abstractmethod + def _verify_data(self) -> bool: + """Verify if the images and masks are present.""" + + def _download(self) -> None: + """Download the dataset.""" + download_url(self.url, self.root, md5=self.md5 if self.checksum else None) + + def _extract(self) -> None: + """Extract the dataset.""" + extract_archive(os.path.join(self.root, self.filename)) + + def plot( + self, + sample: Dict[str, Tensor], + show_titles: bool = True, + suptitle: Optional[str] = None, + ) -> plt.Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + """ + image = np.rollaxis(sample["image"].numpy().astype("uint8").squeeze(), 0, 3) + mask = sample["mask"].numpy().astype("uint8").squeeze() + + num_panels = 2 + showing_predictions = "prediction" in sample + if showing_predictions: + predictions = sample["prediction"].numpy() + num_panels += 1 + + fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 5)) + axs[0].imshow(image) + axs[0].axis("off") + axs[1].imshow(mask, vmin=0, vmax=4, cmap=self._lc_cmap, interpolation="none") + axs[1].axis("off") + if show_titles: + axs[0].set_title("Image") + axs[1].set_title("Mask") + + if showing_predictions: + axs[2].imshow( + predictions, vmin=0, vmax=4, cmap=self._lc_cmap, interpolation="none" + ) + axs[2].axis("off") + if show_titles: + axs[2].set_title("Predictions") + + if suptitle is not None: + plt.suptitle(suptitle) + 
return fig + + +class LandCoverAIGeo(LandCoverAIBase, RasterDataset): + """LandCover.ai Geo dataset. + + See the abstract LandCoverAIBase class to find out more. + + .. versionadded:: 0.5 + """ + + filename_glob = os.path.join("images", "*.tif") + filename_regex = ".*tif" + + def __init__( + self, + root: str = "data", + crs: Optional[CRS] = None, + res: Optional[float] = None, + transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, + cache: bool = True, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new LandCover.ai Geo dataset instance. + + Args: + root: root directory where dataset can be found + crs: :term:`coordinate reference system (CRS)` to warp to + (defaults to the CRS of the first file found) + res: resolution of the dataset in units of CRS + (defaults to the resolution of the first file found) + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + cache: if True, cache file handle to speed up repeated sampling + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + RuntimeError: if ``download=False`` and data is not found, or checksums + don't match + """ + LandCoverAIBase.__init__(self, root, download, checksum) + RasterDataset.__init__(self, root, crs, res, transforms=transforms, cache=cache) + + def _verify_data(self) -> bool: + """Verify if the images and masks are present.""" + img_query = os.path.join(self.root, "images", "*.tif") + mask_query = os.path.join(self.root, "masks", "*.tif") + images = glob.glob(img_query) + masks = glob.glob(mask_query) + return len(images) > 0 and len(images) == len(masks) + + def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: + """Retrieve image/mask and metadata indexed by query. 
+ + Args: + query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index + + Returns: + sample of image, mask and metadata at that index + + Raises: + IndexError: if query is not found in the index + """ + hits = self.index.intersection(tuple(query), objects=True) + img_filepaths = cast(List[str], [hit.object for hit in hits]) + mask_filepaths = [path.replace("images", "masks") for path in img_filepaths] + + if not img_filepaths: + raise IndexError( + f"query: {query} not found in index with bounds: {self.bounds}" + ) + + img = self._merge_files(img_filepaths, query, self.band_indexes) + mask = self._merge_files(mask_filepaths, query, self.band_indexes) + sample = { + "crs": self.crs, + "bbox": query, + "image": img.float(), + "mask": mask.long(), + } + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + +class LandCoverAI(LandCoverAIBase, NonGeoDataset): + """LandCover.ai dataset. + + See the abstract LandCoverAIBase class to find out more. .. 
note:: @@ -60,20 +296,7 @@ class LandCoverAI(NonGeoDataset): the train/val/test split """ - url = "https://landcover.ai.linuxpolska.com/download/landcover.ai.v1.zip" - filename = "landcover.ai.v1.zip" - md5 = "3268c89070e8734b4e91d531c0617e03" sha256 = "15ee4ca9e3fd187957addfa8f0d74ac31bc928a966f76926e11b3c33ea76daa1" - classes = ["Background", "Building", "Woodland", "Water", "Road"] - cmap = ListedColormap( - [ - [0.63921569, 1.0, 0.45098039], - [0.61176471, 0.61176471, 0.61176471], - [0.14901961, 0.45098039, 0.0], - [0.0, 0.77254902, 1.0], - [0.0, 0.0, 0.0], - ] - ) def __init__( self, @@ -100,14 +323,10 @@ def __init__( """ assert split in ["train", "val", "test"] - self.root = root - self.split = split - self.transforms = transforms - self.download = download - self.checksum = checksum - - self._verify() + super().__init__(root, download, checksum) + self.transforms = transforms + self.split = split with open(os.path.join(self.root, split + ".txt")) as f: self.ids = f.readlines() @@ -170,39 +389,13 @@ def _load_target(self, id_: str) -> Tensor: tensor = torch.from_numpy(array).long() return tensor - def _verify(self) -> None: - """Verify the integrity of the dataset. - - Raises: - RuntimeError: if ``download=False`` but dataset is missing or checksum fails - """ - # Check if the extracted files already exist - jpg = os.path.join(self.root, "output", "*_*.jpg") - png = os.path.join(self.root, "output", "*_*_m.png") - if glob.glob(jpg) and glob.glob(png): - return - - # Check if the zip file has already been downloaded - pathname = os.path.join(self.root, self.filename) - if os.path.exists(pathname): - self._extract() - return - - # Check if the user requested to download the dataset - if not self.download: - raise RuntimeError( - f"Dataset not found in `root={self.root}` and `download=False`, " - "either specify a different `root` directory or use `download=True` " - "to automatically download the dataset." 
- ) - - # Download the dataset - self._download() - self._extract() - - def _download(self) -> None: - """Download the dataset.""" - download_url(self.url, self.root, md5=self.md5 if self.checksum else None) + def _verify_data(self) -> bool: + """Verify if the images and masks are present.""" + img_query = os.path.join(self.root, "output", "*_*.jpg") + mask_query = os.path.join(self.root, "output", "*_*_m.png") + images = glob.glob(img_query) + masks = glob.glob(mask_query) + return len(images) > 0 and len(images) == len(masks) def _extract(self) -> None: """Extract the dataset. @@ -210,7 +403,7 @@ def _extract(self) -> None: Raises: AssertionError: if the checksum of split.py does not match """ - extract_archive(os.path.join(self.root, self.filename)) + super()._extract() # Generate train/val/test splits # Always check the sha256 of this file before executing @@ -220,51 +413,3 @@ def _extract(self) -> None: split = f.read().encode("utf-8") assert hashlib.sha256(split).hexdigest() == self.sha256 exec(split) - - def plot( - self, - sample: Dict[str, Tensor], - show_titles: bool = True, - suptitle: Optional[str] = None, - ) -> plt.Figure: - """Plot a sample from the dataset. - - Args: - sample: a sample returned by :meth:`__getitem__` - show_titles: flag indicating whether to show titles above each panel - suptitle: optional string to use as a suptitle - - Returns: - a matplotlib Figure with the rendered sample - - .. 
versionadded:: 0.2 - """ - image = np.rollaxis(sample["image"].numpy(), 0, 3) - mask = sample["mask"].numpy() - - num_panels = 2 - showing_predictions = "prediction" in sample - if showing_predictions: - predictions = sample["prediction"].numpy() - num_panels += 1 - - fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 5)) - axs[0].imshow(image) - axs[0].axis("off") - axs[1].imshow(mask, vmin=0, vmax=4, cmap=self.cmap, interpolation="none") - axs[1].axis("off") - if show_titles: - axs[0].set_title("Image") - axs[1].set_title("Mask") - - if showing_predictions: - axs[2].imshow( - predictions, vmin=0, vmax=4, cmap=self.cmap, interpolation="none" - ) - axs[2].axis("off") - if show_titles: - axs[2].set_title("Predictions") - - if suptitle is not None: - plt.suptitle(suptitle) - return fig