diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 10d49dde8f5..d84ac56d4ce 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -169,8 +169,8 @@ Kenya Crop Type .. autoclass:: CV4AKenyaCropType -Deep Globe Land Cover -^^^^^^^^^^^^^^^^^^^^^ +DeepGlobe Land Cover +^^^^^^^^^^^^^^^^^^^^ .. autoclass:: DeepGlobeLandCover diff --git a/docs/api/non_geo_datasets.csv b/docs/api/non_geo_datasets.csv index ef5fa77744c..d8002401eab 100644 --- a/docs/api/non_geo_datasets.csv +++ b/docs/api/non_geo_datasets.csv @@ -5,7 +5,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands `Cloud Cover Detection`_,S,Sentinel-2,"22,728",2,512x512,10,MSI `COWC`_,"C, R","CSUAV AFRL, ISPRS, LINZ, AGRC","388,435",2,256x256,0.15,RGB `Kenya Crop Type`_,S,Sentinel-2,"4,688",7,"3,035x2,016",10,MSI -`Deep Globe Land Cover`_,S,DigitalGlobe +Vivid,803,7,"2,448x2,448",0.5,RGB +`DeepGlobe Land Cover`_,S,DigitalGlobe +Vivid,803,7,"2,448x2,448",0.5,RGB `DFC2022`_,S,Aerial,"3,981",15,"2,000x2,000",0.5,RGB `ETCI2021 Flood Detection`_,S,Sentinel-1,"66,810",2,256x256,5--20,SAR `EuroSAT`_,C,Sentinel-2,"27,000",10,64x64,10,MSI @@ -34,4 +34,4 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands `Vaihingen`_,S,Aerial,33,6,"1,281--3,816",0.09,RGB `NWPU VHR-10`_,I,"Google Earth, Vaihingen",800,10,"358--1,728",0.08--2,RGB `xView2`_,CD,Maxar,"3,732",4,"1,024x1,024",0.8,RGB -`ZueriCrop`_,"I, T",Sentinel-2,116K,48,24x24,10,MSI \ No newline at end of file +`ZueriCrop`_,"I, T",Sentinel-2,116K,48,24x24,10,MSI diff --git a/environment.yml b/environment.yml index 7b3d71f504b..4f8c98f7dea 100644 --- a/environment.yml +++ b/environment.yml @@ -22,7 +22,7 @@ dependencies: - flake8>=3.8 - ipywidgets>=7 - isort[colors]>=5.8 - - kornia>=0.6.4 + - kornia>=0.6.5 - laspy>=2 - mypy>=0.900 - nbmake>=0.1 diff --git a/requirements/min.old b/requirements/min.old index dae28f65f89..e35c12be2b7 100644 --- a/requirements/min.old +++ b/requirements/min.old @@ -4,7 +4,7 @@ 
setuptools==42.0.0 # install einops==0.3.0 fiona==1.8.0 -kornia==0.6.4 +kornia==0.6.5 matplotlib==3.3.0 numpy==1.17.2 omegaconf==2.1.0 diff --git a/setup.cfg b/setup.cfg index f1ff7e5bf9e..c371e59e7e3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -29,8 +29,8 @@ install_requires = einops>=0.3,<0.7 # fiona 1.8+ required for reading empty files fiona>=1.8,<2 - # kornia 0.6.4+ required for kornia.contrib.compute_padding - kornia>=0.6.4,<0.7 + # kornia 0.6.5+ required due to change in kornia.augmentation API + kornia>=0.6.5,<0.7 # matplotlib 3.3+ required for (H, W, 1) image support in plt.imshow matplotlib>=3.3,<4 # numpy 1.17.2+ required by pytorch-lightning diff --git a/tests/conf/deepglobelandcover_5.yaml b/tests/conf/deepglobelandcover.yaml similarity index 84% rename from tests/conf/deepglobelandcover_5.yaml rename to tests/conf/deepglobelandcover.yaml index b3262962bd0..2bb2cc5b53b 100644 --- a/tests/conf/deepglobelandcover_5.yaml +++ b/tests/conf/deepglobelandcover.yaml @@ -14,6 +14,8 @@ experiment: ignore_index: null datamodule: root: "tests/data/deepglobelandcover" + num_tiles_per_batch: 1 + num_patches_per_tile: 1 + patch_size: 2 val_split_pct: 0.5 - batch_size: 1 num_workers: 0 diff --git a/tests/conf/deepglobelandcover_0.yaml b/tests/conf/deepglobelandcover_0.yaml deleted file mode 100644 index a835df8995d..00000000000 --- a/tests/conf/deepglobelandcover_0.yaml +++ /dev/null @@ -1,19 +0,0 @@ -experiment: - task: "deepglobelandcover" - module: - loss: "ce" - model: "unet" - backbone: "resnet18" - weights: null - learning_rate: 1e-3 - learning_rate_schedule_patience: 6 - verbose: false - in_channels: 3 - num_classes: 7 - num_filters: 1 - ignore_index: null - datamodule: - root: "tests/data/deepglobelandcover" - val_split_pct: 0.0 - batch_size: 1 - num_workers: 0 diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 14fc4daeb20..73245f6c717 100644 --- a/tests/trainers/test_segmentation.py +++ 
b/tests/trainers/test_segmentation.py @@ -36,8 +36,7 @@ class TestSemanticSegmentationTask: "name,classname", [ ("chesapeake_cvpr_5", ChesapeakeCVPRDataModule), - ("deepglobelandcover_0", DeepGlobeLandCoverDataModule), - ("deepglobelandcover_5", DeepGlobeLandCoverDataModule), + ("deepglobelandcover", DeepGlobeLandCoverDataModule), ("etci2021", ETCI2021DataModule), ("inria_train", InriaAerialImageLabelingDataModule), ("inria_val", InriaAerialImageLabelingDataModule), diff --git a/torchgeo/datamodules/deepglobelandcover.py b/torchgeo/datamodules/deepglobelandcover.py index 955467048e3..ac71f69bd77 100644 --- a/torchgeo/datamodules/deepglobelandcover.py +++ b/torchgeo/datamodules/deepglobelandcover.py @@ -3,14 +3,17 @@ """DeepGlobe Land Cover Classification Challenge datamodule.""" -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple, Union import matplotlib.pyplot as plt import pytorch_lightning as pl -from torch.utils.data import DataLoader, Dataset -from torchvision.transforms import Compose +from kornia.augmentation import Normalize +from torch.utils.data import DataLoader from ..datasets import DeepGlobeLandCover +from ..samplers.utils import _to_tuple +from ..transforms import AugmentationSequential +from ..transforms.transforms import _ExtractTensorPatches, _RandomNCrop from .utils import dataset_split @@ -18,72 +21,74 @@ class DeepGlobeLandCoverDataModule(pl.LightningDataModule): """LightningDataModule implementation for the DeepGlobe Land Cover dataset. Uses the train/test splits from the dataset. - """ def __init__( self, - batch_size: int = 64, - num_workers: int = 0, + num_tiles_per_batch: int = 16, + num_patches_per_tile: int = 16, + patch_size: Union[Tuple[int, int], int] = 64, val_split_pct: float = 0.2, + num_workers: int = 0, **kwargs: Any, ) -> None: - """Initialize a LightningDataModule for DeepGlobe Land Cover based DataLoaders. + """Initialize a new LightningDataModule instance. 
+ + The DeepGlobe Land Cover dataset contains images that are too large to pass + directly through a model. Instead, we randomly sample patches from image tiles + during training and chop up image tiles into patch grids during evaluation. + During training, the effective batch size is equal to + ``num_tiles_per_batch`` x ``num_patches_per_tile``. Args: - batch_size: The batch size to use in all created DataLoaders - num_workers: The number of workers to use in all created DataLoaders - val_split_pct: What percentage of the dataset to use as a validation set + num_tiles_per_batch: The number of image tiles to sample from during + training + num_patches_per_tile: The number of patches to randomly sample from each + image tile during training + patch_size: The size of each patch, either ``size`` or ``(height, width)``. + Should be a multiple of 32 for most segmentation architectures + val_split_pct: The percentage of the dataset to use as a validation set + num_workers: The number of workers to use for parallel data loading **kwargs: Additional keyword arguments passed to :class:`~torchgeo.datasets.DeepGlobeLandCover` + + .. versionchanged:: 0.4 + *batch_size* was replaced by *num_tiles_per_batch*, *num_patches_per_tile*, + and *patch_size*. """ super().__init__() - self.batch_size = batch_size - self.num_workers = num_workers + + self.num_tiles_per_batch = num_tiles_per_batch + self.num_patches_per_tile = num_patches_per_tile + self.patch_size = _to_tuple(patch_size) self.val_split_pct = val_split_pct + self.num_workers = num_workers self.kwargs = kwargs - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset. 
- - Args: - sample: input image dictionary - - Returns: - preprocessed sample - """ - sample["image"] = sample["image"].float() - sample["image"] /= 255.0 - return sample + self.train_transform = AugmentationSequential( + Normalize(mean=0.0, std=255.0), + _RandomNCrop(self.patch_size, self.num_patches_per_tile), + data_keys=["image", "mask"], + ) + self.test_transform = AugmentationSequential( + Normalize(mean=0.0, std=255.0), + _ExtractTensorPatches(self.patch_size), + data_keys=["image", "mask"], + ) def setup(self, stage: Optional[str] = None) -> None: - """Initialize the main ``Dataset`` objects. + """Initialize the main Dataset objects. This method is called once per GPU per run. Args: stage: stage to set up """ - transforms = Compose([self.preprocess]) - - dataset = DeepGlobeLandCover( - split="train", transforms=transforms, **self.kwargs - ) - - self.train_dataset: Dataset[Any] - self.val_dataset: Dataset[Any] - - if self.val_split_pct > 0.0: - self.train_dataset, self.val_dataset, _ = dataset_split( - dataset, val_pct=self.val_split_pct, test_pct=0.0 - ) - else: - self.train_dataset = dataset - self.val_dataset = dataset - - self.test_dataset = DeepGlobeLandCover( - split="test", transforms=transforms, **self.kwargs + train_dataset = DeepGlobeLandCover(split="train", **self.kwargs) + self.train_dataset, self.val_dataset = dataset_split( + train_dataset, self.val_split_pct ) + self.test_dataset = DeepGlobeLandCover(split="test", **self.kwargs) def train_dataloader(self) -> DataLoader[Dict[str, Any]]: """Return a DataLoader for training. 
@@ -93,7 +98,7 @@ def train_dataloader(self) -> DataLoader[Dict[str, Any]]: """ return DataLoader( self.train_dataset, - batch_size=self.batch_size, + batch_size=self.num_tiles_per_batch, num_workers=self.num_workers, shuffle=True, ) @@ -105,10 +110,7 @@ def val_dataloader(self) -> DataLoader[Dict[str, Any]]: validation data loader """ return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, + self.val_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False ) def test_dataloader(self) -> DataLoader[Dict[str, Any]]: @@ -118,12 +120,35 @@ def test_dataloader(self) -> DataLoader[Dict[str, Any]]: testing data loader """ return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - num_workers=self.num_workers, - shuffle=False, + self.test_dataset, batch_size=1, num_workers=self.num_workers, shuffle=False ) + def on_after_batch_transfer( + self, batch: Dict[str, Any], dataloader_idx: int + ) -> Dict[str, Any]: + """Apply augmentations to batch after transferring to GPU. + + Args: + batch: A batch of data that needs to be altered or augmented + dataloader_idx: The index of the dataloader to which the batch belongs + + Returns: + A batch of data + """ + # Kornia requires masks to have a channel dimension + batch["mask"] = batch["mask"].unsqueeze(1) + + if self.trainer: + if self.trainer.training: + batch = self.train_transform(batch) + elif self.trainer.validating or self.trainer.testing: + batch = self.test_transform(batch) + + # Torchmetrics does not support masks with a channel dimension + batch["mask"] = batch["mask"].squeeze(1) + + return batch + def plot(self, *args: Any, **kwargs: Any) -> plt.Figure: """Run :meth:`torchgeo.datasets.DeepGlobeLandCover.plot`. 
diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index 92d26dc87b2..2c72de9b69c 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -3,7 +3,7 @@ """InriaAerialImageLabeling datamodule.""" -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Tuple, Union import kornia.augmentation as K import matplotlib.pyplot as plt @@ -69,7 +69,7 @@ def __init__( self.num_workers = num_workers self.val_split_pct = val_split_pct self.test_split_pct = test_split_pct - self.patch_size = cast(Tuple[int, int], _to_tuple(patch_size)) + self.patch_size = _to_tuple(patch_size) self.num_patches_per_tile = num_patches_per_tile self.kwargs = kwargs diff --git a/torchgeo/datasets/deepglobelandcover.py b/torchgeo/datasets/deepglobelandcover.py index 7adcddcda19..1113e9bb64b 100644 --- a/torchgeo/datasets/deepglobelandcover.py +++ b/torchgeo/datasets/deepglobelandcover.py @@ -173,7 +173,7 @@ def _load_image(self, index: int) -> Tensor: array: "np.typing.NDArray[np.int_]" = np.array(img) tensor = torch.from_numpy(array) # Convert from HxWxC to CxHxW - tensor = tensor.permute((2, 0, 1)) + tensor = tensor.permute((2, 0, 1)).to(torch.float32) return tensor def _load_target(self, index: int) -> Tensor: diff --git a/torchgeo/samplers/utils.py b/torchgeo/samplers/utils.py index ecf4cef3110..d36f9a0550d 100644 --- a/torchgeo/samplers/utils.py +++ b/torchgeo/samplers/utils.py @@ -4,13 +4,23 @@ """Common sampler utilities.""" import math -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, overload import torch from ..datasets import BoundingBox +@overload +def _to_tuple(value: Union[Tuple[int, int], int]) -> Tuple[int, int]: + ... + + +@overload +def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]: + ... 
+ + def _to_tuple(value: Union[Tuple[float, float], float]) -> Tuple[float, float]: """Convert value to a tuple if it is not already a tuple. diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index fe07980b3b3..3990b9e82dc 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -3,14 +3,19 @@ """TorchGeo transforms.""" -from typing import Dict, List, Union +from typing import Any, Dict, List, Optional, Tuple, Union -import kornia.augmentation as K +import kornia import torch +from kornia.augmentation import GeometricAugmentationBase2D +from kornia.augmentation.random_generator import CropGenerator +from kornia.contrib import compute_padding, extract_tensor_patches +from kornia.geometry import crop_by_indices from torch import Tensor from torch.nn.modules import Module +# TODO: contribute these to Kornia and delete this file class AugmentationSequential(Module): """Wrapper around kornia AugmentationSequential to handle input dicts.""" @@ -33,7 +38,7 @@ def __init__(self, *args: Module, data_keys: List[str]) -> None: else: keys.append(key) - self.augs = K.AugmentationSequential(*args, data_keys=keys) + self.augs = kornia.augmentation.AugmentationSequential(*args, data_keys=keys) def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]: """Perform augmentations and update data dict. @@ -69,3 +74,147 @@ def forward(self, sample: Dict[str, Tensor]) -> Dict[str, Tensor]: sample["boxes"] = sample["boxes"].to(boxes_dtype) return sample + + +class _ExtractTensorPatches(GeometricAugmentationBase2D): + """Chop up a tensor into a grid.""" + + def __init__(self, window_size: Union[int, Tuple[int, int]]) -> None: + """Initialize a new _ExtractTensorPatches instance. 
+ + Args: + window_size: the size of each patch + """ + super().__init__(p=1) + self.flags = {"window_size": window_size} + + def compute_transformation( + self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any] + ) -> Tensor: + """Compute the transformation. + + Args: + input: the input tensor + params: generated parameters + flags: static parameters + + Returns: + the transformation + """ + out: Tensor = self.identity_matrix(input) + return out + + def apply_transform( + self, + input: Tensor, + params: Dict[str, Tensor], + flags: Dict[str, Any], + transform: Optional[Tensor] = None, + ) -> Tensor: + """Apply the transform. + + Args: + input: the input tensor + params: generated parameters + flags: static parameters + transform: the geometric transformation tensor + + Returns: + the augmented input + """ + size = flags["window_size"] + h, w = input.shape[-2:] + padding = compute_padding((h, w), size) + input = extract_tensor_patches(input, size, size, padding) + input = torch.flatten(input, 0, 1) # [B, N, C?, H, W] -> [B*N, C?, H, W] + return input + + +class _RandomNCrop(GeometricAugmentationBase2D): + """Take N random crops of a tensor.""" + + def __init__(self, size: Tuple[int, int], num: int) -> None: + """Initialize a new _RandomNCrop instance. + + Args: + size: desired output size (out_h, out_w) of the crop + num: number of crops to take + """ + super().__init__(p=1) + self._param_generator: _NCropGenerator = _NCropGenerator(size, num) + self.flags = {"size": size, "num": num} + + def compute_transformation( + self, input: Tensor, params: Dict[str, Tensor], flags: Dict[str, Any] + ) -> Tensor: + """Compute the transformation. 
+ + Args: + input: the input tensor + params: generated parameters + flags: static parameters + + Returns: + the transformation + """ + out: Tensor = self.identity_matrix(input) + return out + + def apply_transform( + self, + input: Tensor, + params: Dict[str, Tensor], + flags: Dict[str, Any], + transform: Optional[Tensor] = None, + ) -> Tensor: + """Apply the transform. + + Args: + input: the input tensor + params: generated parameters + flags: static parameters + transform: the geometric transformation tensor + + Returns: + the augmented input + """ + out = [] + for i in range(flags["num"]): + out.append(crop_by_indices(input, params["src"][i], flags["size"])) + return torch.cat(out) + + +class _NCropGenerator(CropGenerator): + """Generate N random crops.""" + + def __init__(self, size: Union[Tuple[int, int], Tensor], num: int) -> None: + """Initialize a new _NCropGenerator instance. + + Args: + size: desired output size (out_h, out_w) of the crop + num: number of crops to generate + """ + super().__init__(size) + self.num = num + + def forward( + self, batch_shape: torch.Size, same_on_batch: bool = False + ) -> Dict[str, Tensor]: + """Generate the crops. + + Args: + batch_shape: input size (b, c?, in_h, in_w) + same_on_batch: apply the same transformation across the batch + + Returns: + the randomly generated parameters + """ + out = [] + for _ in range(self.num): + out.append(super().forward(batch_shape, same_on_batch)) + return { + "src": torch.stack([x["src"] for x in out]), + "dst": torch.stack([x["dst"] for x in out]), + "input_size": out[0]["input_size"], + "output_size": out[0]["output_size"], + }