diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index 87549c9719f..d5af9d67434 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -93,6 +93,12 @@ iNaturalist
.. autoclass:: INaturalist
+LandCover.ai Geo
+^^^^^^^^^^^^^^^^
+
+.. autoclass:: LandCoverAIBase
+.. autoclass:: LandCoverAIGeo
+
Landsat
^^^^^^^
diff --git a/docs/api/geo_datasets.csv b/docs/api/geo_datasets.csv
index 9bee384cd95..66b9245b865 100644
--- a/docs/api/geo_datasets.csv
+++ b/docs/api/geo_datasets.csv
@@ -12,6 +12,7 @@ Dataset,Type,Source,Size (px),Resolution (m)
`GBIF`_,Points,Citizen Scientists,-,-
`GlobBiomass`_,Masks,Landsat,"45,000x45,000",100
`iNaturalist`_,Points,Citizen Scientists,-,-
+`LandCover.ai Geo`_,"Imagery, Masks",Aerial,"4,200--9,500",0.25--0.5
`Landsat`_,Imagery,Landsat,"8,900x8,900",30
`NAIP`_,Imagery,Aerial,"6,100x7,600",1
`Open Buildings`_,Geometries,"Maxar, CNES/Airbus",-,-
diff --git a/tests/datasets/test_landcoverai.py b/tests/datasets/test_landcoverai.py
index e8e64680ef0..c4f11e05b61 100644
--- a/tests/datasets/test_landcoverai.py
+++ b/tests/datasets/test_landcoverai.py
@@ -14,13 +14,62 @@
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
-from torchgeo.datasets import LandCoverAI
+from torchgeo.datasets import BoundingBox, LandCoverAI, LandCoverAIGeo
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
+class TestLandCoverAIGeo:
+ @pytest.fixture
+ def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> LandCoverAIGeo:
+ monkeypatch.setattr(torchgeo.datasets.landcoverai, "download_url", download_url)
+ md5 = "ff8998857cc8511f644d3f7d0f3688d0"
+ monkeypatch.setattr(LandCoverAIGeo, "md5", md5)
+ url = os.path.join("tests", "data", "landcoverai", "landcover.ai.v1.zip")
+ monkeypatch.setattr(LandCoverAIGeo, "url", url)
+ root = str(tmp_path)
+ transforms = nn.Identity()
+ return LandCoverAIGeo(root, transforms=transforms, download=True, checksum=True)
+
+ def test_getitem(self, dataset: LandCoverAIGeo) -> None:
+ x = dataset[dataset.bounds]
+ assert isinstance(x, dict)
+ assert isinstance(x["image"], torch.Tensor)
+ assert isinstance(x["mask"], torch.Tensor)
+
+ def test_already_extracted(self, dataset: LandCoverAIGeo) -> None:
+ LandCoverAIGeo(root=dataset.root, download=True)
+
+ def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
+ url = os.path.join("tests", "data", "landcoverai", "landcover.ai.v1.zip")
+ root = str(tmp_path)
+ shutil.copy(url, root)
+ LandCoverAIGeo(root)
+
+ def test_not_downloaded(self, tmp_path: Path) -> None:
+ with pytest.raises(RuntimeError, match="Dataset not found"):
+ LandCoverAIGeo(str(tmp_path))
+
+ def test_out_of_bounds_query(self, dataset: LandCoverAIGeo) -> None:
+ query = BoundingBox(0, 0, 0, 0, 0, 0)
+ with pytest.raises(
+ IndexError, match="query: .* not found in index with bounds:"
+ ):
+ dataset[query]
+
+ def test_plot(self, dataset: LandCoverAIGeo) -> None:
+ x = dataset[dataset.bounds].copy()
+ dataset.plot(x, suptitle="Test")
+ plt.close()
+ dataset.plot(x, show_titles=False)
+ plt.close()
+ x["prediction"] = x["mask"][:, :, 0].clone().unsqueeze(2)
+ dataset.plot(x)
+ plt.close()
+
+
class TestLandCoverAI:
@pytest.fixture(params=["train", "val", "test"])
def dataset(
diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index 724fafad675..5bcf3e65195 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -53,7 +53,7 @@
from .idtrees import IDTReeS
from .inaturalist import INaturalist
from .inria import InriaAerialImageLabeling
-from .landcoverai import LandCoverAI
+from .landcoverai import LandCoverAI, LandCoverAIBase, LandCoverAIGeo
from .landsat import (
Landsat,
Landsat1,
@@ -137,6 +137,8 @@
"GBIF",
"GlobBiomass",
"INaturalist",
+ "LandCoverAIBase",
+ "LandCoverAIGeo",
"Landsat",
"Landsat1",
"Landsat2",
diff --git a/torchgeo/datasets/landcoverai.py b/torchgeo/datasets/landcoverai.py
index c9e7c80c41f..aad853bb2f2 100644
--- a/torchgeo/datasets/landcoverai.py
+++ b/torchgeo/datasets/landcoverai.py
@@ -2,31 +2,33 @@
# Licensed under the MIT License.
"""LandCover.ai dataset."""
-
+import abc
import glob
import hashlib
import os
from functools import lru_cache
-from typing import Callable, Dict, Optional
+from typing import Any, Callable, Dict, List, Optional, cast
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.colors import ListedColormap
from PIL import Image
+from rasterio.crs import CRS
from torch import Tensor
+from torch.utils.data import Dataset
-from .geo import NonGeoDataset
-from .utils import download_url, extract_archive, working_dir
+from .geo import NonGeoDataset, RasterDataset
+from .utils import BoundingBox, download_url, extract_archive, working_dir
-class LandCoverAI(NonGeoDataset):
- r"""LandCover.ai dataset.
+class LandCoverAIBase(Dataset[Dict[str, Any]], abc.ABC):
+ r"""Abstract base class for LandCover.ai Geo and NonGeo datasets.
- The `LandCover.ai `__ (Land Cover from Aerial Imagery)
- dataset is a dataset for automatic mapping of buildings, woodlands, water and
- roads from aerial images. This implementation is specifically for Version 1 of
- Landcover.ai.
+ The `LandCover.ai `__ (Land Cover from
+ Aerial Imagery) dataset is a dataset for automatic mapping of buildings, woodlands,
+ water and roads from aerial images. This implementation is specifically for
+ Version 1 of LandCover.ai.
Dataset features:
@@ -50,7 +52,241 @@ class LandCoverAI(NonGeoDataset):
If you use this dataset in your research, please cite the following paper:
- * https://arxiv.org/abs/2005.02264v3
+ * https://arxiv.org/abs/2005.02264v4
+
+ .. versionadded:: 0.5
+ """
+
+ url = "https://landcover.ai.linuxpolska.com/download/landcover.ai.v1.zip"
+ filename = "landcover.ai.v1.zip"
+ md5 = "3268c89070e8734b4e91d531c0617e03"
+ classes = ["Background", "Building", "Woodland", "Water", "Road"]
+ cmap = {
+ 0: (0, 0, 0, 0),
+ 1: (97, 74, 74, 255),
+ 2: (38, 115, 0, 255),
+ 3: (0, 197, 255, 255),
+ 4: (207, 207, 207, 255),
+ }
+
+ def __init__(
+ self, root: str = "data", download: bool = False, checksum: bool = False
+ ) -> None:
+ """Initialize a new LandCover.ai dataset instance.
+
+ Args:
+ root: root directory where dataset can be found
+ transforms: a function/transform that takes input sample and its target as
+ entry and returns a transformed version
+ cache: if True, cache file handle to speed up repeated sampling
+ download: if True, download dataset and store it in the root directory
+ checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+ Raises:
+ RuntimeError: if ``download=False`` and data is not found, or checksums
+ don't match
+ """
+ self.root = root
+ self.download = download
+ self.checksum = checksum
+
+ lc_colors = np.zeros((max(self.cmap.keys()) + 1, 4))
+ lc_colors[list(self.cmap.keys())] = list(self.cmap.values())
+ lc_colors = lc_colors[:, :3] / 255
+ self._lc_cmap = ListedColormap(lc_colors)
+
+ self._verify()
+
+ def _verify(self) -> None:
+ """Verify the integrity of the dataset.
+
+ Raises:
+ RuntimeError: if ``download=False`` but dataset is missing or checksum fails
+ """
+ if self._verify_data():
+ return
+
+ # Check if the zip file has already been downloaded
+ pathname = os.path.join(self.root, self.filename)
+ if os.path.exists(pathname):
+ self._extract()
+ return
+
+ # Check if the user requested to download the dataset
+ if not self.download:
+ raise RuntimeError(
+ f"Dataset not found in `root={self.root}` and `download=False`, "
+ "either specify a different `root` directory or use `download=True` "
+ "to automatically download the dataset."
+ )
+
+ # Download the dataset
+ self._download()
+ self._extract()
+
+ @abc.abstractmethod
+ def __getitem__(self, query: Any) -> Dict[str, Any]:
+ """Retrieve image, mask and metadata indexed by index.
+
+ Args:
+ query: coordinates or an index
+
+ Returns:
+ sample of image, mask and metadata at that index
+
+ Raises:
+ IndexError: if query is not found in the index
+ """
+
+ @abc.abstractmethod
+ def _verify_data(self) -> bool:
+ """Verify if the images and masks are present."""
+
+ def _download(self) -> None:
+ """Download the dataset."""
+ download_url(self.url, self.root, md5=self.md5 if self.checksum else None)
+
+ def _extract(self) -> None:
+ """Extract the dataset."""
+ extract_archive(os.path.join(self.root, self.filename))
+
+ def plot(
+ self,
+ sample: Dict[str, Tensor],
+ show_titles: bool = True,
+ suptitle: Optional[str] = None,
+ ) -> plt.Figure:
+ """Plot a sample from the dataset.
+
+ Args:
+ sample: a sample returned by :meth:`__getitem__`
+ show_titles: flag indicating whether to show titles above each panel
+ suptitle: optional string to use as a suptitle
+
+ Returns:
+ a matplotlib Figure with the rendered sample
+ """
+ image = np.rollaxis(sample["image"].numpy().astype("uint8").squeeze(), 0, 3)
+ mask = sample["mask"].numpy().astype("uint8").squeeze()
+
+ num_panels = 2
+ showing_predictions = "prediction" in sample
+ if showing_predictions:
+ predictions = sample["prediction"].numpy()
+ num_panels += 1
+
+ fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 5))
+ axs[0].imshow(image)
+ axs[0].axis("off")
+ axs[1].imshow(mask, vmin=0, vmax=4, cmap=self._lc_cmap, interpolation="none")
+ axs[1].axis("off")
+ if show_titles:
+ axs[0].set_title("Image")
+ axs[1].set_title("Mask")
+
+ if showing_predictions:
+ axs[2].imshow(
+ predictions, vmin=0, vmax=4, cmap=self._lc_cmap, interpolation="none"
+ )
+ axs[2].axis("off")
+ if show_titles:
+ axs[2].set_title("Predictions")
+
+ if suptitle is not None:
+ plt.suptitle(suptitle)
+ return fig
+
+
+class LandCoverAIGeo(LandCoverAIBase, RasterDataset):
+ """LandCover.ai Geo dataset.
+
+ See the abstract LandCoverAIBase class to find out more.
+
+ .. versionadded:: 0.5
+ """
+
+ filename_glob = os.path.join("images", "*.tif")
+ filename_regex = ".*tif"
+
+ def __init__(
+ self,
+ root: str = "data",
+ crs: Optional[CRS] = None,
+ res: Optional[float] = None,
+ transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
+ cache: bool = True,
+ download: bool = False,
+ checksum: bool = False,
+ ) -> None:
+ """Initialize a new LandCover.ai NonGeo dataset instance.
+
+ Args:
+ root: root directory where dataset can be found
+ crs: :term:`coordinate reference system (CRS)` to warp to
+ (defaults to the CRS of the first file found)
+ res: resolution of the dataset in units of CRS
+ (defaults to the resolution of the first file found)
+ transforms: a function/transform that takes input sample and its target as
+ entry and returns a transformed version
+ cache: if True, cache file handle to speed up repeated sampling
+ download: if True, download dataset and store it in the root directory
+ checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+ Raises:
+ RuntimeError: if ``download=False`` and data is not found, or checksums
+ don't match
+ """
+ LandCoverAIBase.__init__(self, root, download, checksum)
+ RasterDataset.__init__(self, root, crs, res, transforms=transforms, cache=cache)
+
+ def _verify_data(self) -> bool:
+ """Verify if the images and masks are present."""
+ img_query = os.path.join(self.root, "images", "*.tif")
+ mask_query = os.path.join(self.root, "masks", "*.tif")
+ images = glob.glob(img_query)
+ masks = glob.glob(mask_query)
+ return len(images) > 0 and len(images) == len(masks)
+
+ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
+ """Retrieve image/mask and metadata indexed by query.
+
+ Args:
+ query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
+
+ Returns:
+ sample of image, mask and metadata at that index
+
+ Raises:
+ IndexError: if query is not found in the index
+ """
+ hits = self.index.intersection(tuple(query), objects=True)
+ img_filepaths = cast(List[str], [hit.object for hit in hits])
+ mask_filepaths = [path.replace("images", "masks") for path in img_filepaths]
+
+ if not img_filepaths:
+ raise IndexError(
+ f"query: {query} not found in index with bounds: {self.bounds}"
+ )
+
+ img = self._merge_files(img_filepaths, query, self.band_indexes)
+ mask = self._merge_files(mask_filepaths, query, self.band_indexes)
+ sample = {
+ "crs": self.crs,
+ "bbox": query,
+ "image": img.float(),
+ "mask": mask.long(),
+ }
+
+ if self.transforms is not None:
+ sample = self.transforms(sample)
+
+ return sample
+
+
+class LandCoverAI(LandCoverAIBase, NonGeoDataset):
+ """LandCover.ai dataset.
+
+ See the abstract LandCoverAIBase class to find out more.
.. note::
@@ -60,20 +296,7 @@ class LandCoverAI(NonGeoDataset):
the train/val/test split
"""
- url = "https://landcover.ai.linuxpolska.com/download/landcover.ai.v1.zip"
- filename = "landcover.ai.v1.zip"
- md5 = "3268c89070e8734b4e91d531c0617e03"
sha256 = "15ee4ca9e3fd187957addfa8f0d74ac31bc928a966f76926e11b3c33ea76daa1"
- classes = ["Background", "Building", "Woodland", "Water", "Road"]
- cmap = ListedColormap(
- [
- [0.63921569, 1.0, 0.45098039],
- [0.61176471, 0.61176471, 0.61176471],
- [0.14901961, 0.45098039, 0.0],
- [0.0, 0.77254902, 1.0],
- [0.0, 0.0, 0.0],
- ]
- )
def __init__(
self,
@@ -100,14 +323,10 @@ def __init__(
"""
assert split in ["train", "val", "test"]
- self.root = root
- self.split = split
- self.transforms = transforms
- self.download = download
- self.checksum = checksum
-
- self._verify()
+ super().__init__(root, download, checksum)
+ self.transforms = transforms
+ self.split = split
with open(os.path.join(self.root, split + ".txt")) as f:
self.ids = f.readlines()
@@ -170,39 +389,13 @@ def _load_target(self, id_: str) -> Tensor:
tensor = torch.from_numpy(array).long()
return tensor
- def _verify(self) -> None:
- """Verify the integrity of the dataset.
-
- Raises:
- RuntimeError: if ``download=False`` but dataset is missing or checksum fails
- """
- # Check if the extracted files already exist
- jpg = os.path.join(self.root, "output", "*_*.jpg")
- png = os.path.join(self.root, "output", "*_*_m.png")
- if glob.glob(jpg) and glob.glob(png):
- return
-
- # Check if the zip file has already been downloaded
- pathname = os.path.join(self.root, self.filename)
- if os.path.exists(pathname):
- self._extract()
- return
-
- # Check if the user requested to download the dataset
- if not self.download:
- raise RuntimeError(
- f"Dataset not found in `root={self.root}` and `download=False`, "
- "either specify a different `root` directory or use `download=True` "
- "to automatically download the dataset."
- )
-
- # Download the dataset
- self._download()
- self._extract()
-
- def _download(self) -> None:
- """Download the dataset."""
- download_url(self.url, self.root, md5=self.md5 if self.checksum else None)
+ def _verify_data(self) -> bool:
+ """Verify if the images and masks are present."""
+ img_query = os.path.join(self.root, "output", "*_*.jpg")
+ mask_query = os.path.join(self.root, "output", "*_*_m.png")
+ images = glob.glob(img_query)
+ masks = glob.glob(mask_query)
+ return len(images) > 0 and len(images) == len(masks)
def _extract(self) -> None:
"""Extract the dataset.
@@ -210,7 +403,7 @@ def _extract(self) -> None:
Raises:
AssertionError: if the checksum of split.py does not match
"""
- extract_archive(os.path.join(self.root, self.filename))
+ super()._extract()
# Generate train/val/test splits
# Always check the sha256 of this file before executing
@@ -220,51 +413,3 @@ def _extract(self) -> None:
split = f.read().encode("utf-8")
assert hashlib.sha256(split).hexdigest() == self.sha256
exec(split)
-
- def plot(
- self,
- sample: Dict[str, Tensor],
- show_titles: bool = True,
- suptitle: Optional[str] = None,
- ) -> plt.Figure:
- """Plot a sample from the dataset.
-
- Args:
- sample: a sample returned by :meth:`__getitem__`
- show_titles: flag indicating whether to show titles above each panel
- suptitle: optional string to use as a suptitle
-
- Returns:
- a matplotlib Figure with the rendered sample
-
- .. versionadded:: 0.2
- """
- image = np.rollaxis(sample["image"].numpy(), 0, 3)
- mask = sample["mask"].numpy()
-
- num_panels = 2
- showing_predictions = "prediction" in sample
- if showing_predictions:
- predictions = sample["prediction"].numpy()
- num_panels += 1
-
- fig, axs = plt.subplots(1, num_panels, figsize=(num_panels * 4, 5))
- axs[0].imshow(image)
- axs[0].axis("off")
- axs[1].imshow(mask, vmin=0, vmax=4, cmap=self.cmap, interpolation="none")
- axs[1].axis("off")
- if show_titles:
- axs[0].set_title("Image")
- axs[1].set_title("Mask")
-
- if showing_predictions:
- axs[2].imshow(
- predictions, vmin=0, vmax=4, cmap=self.cmap, interpolation="none"
- )
- axs[2].axis("off")
- if show_titles:
- axs[2].set_title("Predictions")
-
- if suptitle is not None:
- plt.suptitle(suptitle)
- return fig