Skip to content

Commit

Permalink
SSL4EO-S12: add new dataset/datamodule (microsoft#1151)
Browse files Browse the repository at this point in the history
* SSL4EO-S12: add new dataset

* Style fixes

* 100% coverage

* fix mypy

* black fixes

* mypy fix

* Convert from db to power

* Don't cast to numpy

* Remove comments referring to SeCo

* SSL4EO: add extraction time

* Add RandomSeasonContrast

* Fix axes indexing

* Add datamodule

* fix tests

* mypy fixes

* fix missing import

* Fix tests

* isort fix

* Typo fix

* s2c: add B10

* Update test channels

* S2 plotting was broken

* Fix plotting

* Black fix

* Rename conf files

* Remove file introduced by bad merge

* Fix pixel size of bands

* black fix

* Better S1 plotting

---------

Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
  • Loading branch information
adamjstewart and calebrob6 authored Apr 15, 2023
1 parent 22e8040 commit 7c1d05a
Show file tree
Hide file tree
Showing 234 changed files with 581 additions and 5 deletions.
5 changes: 5 additions & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ SpaceNet

.. autoclass:: SpaceNet1DataModule

SSL4EO
^^^^^^

.. autoclass:: SSL4EOS12DataModule

Tropical Cyclone
^^^^^^^^^^^^^^^^

Expand Down
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,11 @@ SpaceNet
.. autoclass:: SpaceNet6
.. autoclass:: SpaceNet7

SSL4EO
^^^^^^

.. autoclass:: SSL4EOS12

Tropical Cyclone
^^^^^^^^^^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/non_geo_datasets.csv
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands
`SEN12MS`_,S,"Sentinel-1/2, MODIS","180,662",33,256x256,10,"SAR, MSI"
`So2Sat`_,C,Sentinel-1/2,"400,673",17,32x32,10,"SAR, MSI"
`SpaceNet`_,I,WorldView-2/3 Planet Lab Dove,"1,889--28,728",2,102--900,0.5--4,MSI
`SSL4EO`_,T,Sentinel-1/2,1M,-,264x264,10,"SAR, MSI"
`Tropical Cyclone`_,R,GOES 8--16,"108,110",-,256x256,4K--8K,MSI
`UC Merced`_,C,USGS National Map,"21,000",21,256x256,0.3,RGB
`USAVars`_,R,NAIP Aerial,100K,-,-,4,"RGB, NIR"
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
13 changes: 13 additions & 0 deletions tests/conf/ssl4eo_s12_byol_1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# BYOL trainer test configuration for SSL4EO-S12 with a single season.
experiment:
  task: "ssl4eo_s12"
  module:
    in_channels: 13
    backbone: "resnet18"
    learning_rate: 1e-3
    learning_rate_schedule_patience: 6
    weights: null
  datamodule:
    root: "tests/data/ssl4eo/s12"
    seasons: 1
    batch_size: 2
    num_workers: 0
13 changes: 13 additions & 0 deletions tests/conf/ssl4eo_s12_byol_2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# BYOL trainer test configuration for SSL4EO-S12 with two seasons.
experiment:
  task: "ssl4eo_s12"
  module:
    in_channels: 13
    backbone: "resnet18"
    learning_rate: 1e-3
    learning_rate_schedule_patience: 6
    weights: null
  datamodule:
    root: "tests/data/ssl4eo/s12"
    seasons: 2
    batch_size: 2
    num_workers: 0
167 changes: 167 additions & 0 deletions tests/data/ssl4eo/s12/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import hashlib
import os
import shutil
from typing import Dict, List, Union

import numpy as np
import rasterio
from rasterio import Affine
from rasterio.crs import CRS

# Width and height, in pixels, of the full-size rasters written by this script.
SIZE = 36

# Seed NumPy's global RNG once so that repeated runs draw the same values.
np.random.seed(0)

# Recursive layout description: a dict maps a directory name to its contents,
# a list names the files to create inside the current directory.
FILENAME_HIERARCHY = Union[Dict[str, "FILENAME_HIERARCHY"], List[str]]

# Files written into every scene directory of the "s1" tree.
s1 = ["VH.tif", "VV.tif"]

# Files written into every scene directory of the "s2c" tree.
s2c = [
    f"B{band}.tif"
    for band in (
        "1",
        "2",
        "3",
        "4",
        "5",
        "6",
        "7",
        "8",
        "8A",
        "9",
        "10",
        "11",
        "12",
    )
]

# The "s2a" tree holds the same files as "s2c" minus B10.
s2a = [name for name in s2c if name != "B10.tif"]

# Scene directory names of the "s1" tree, keyed by their parent directory.
_S1_SCENES = {
    "0000000": [
        "S1A_IW_GRDH_1SDV_20200329T001515_20200329T001540_031883_03AE27_9BAF",
        "S1A_IW_GRDH_1SDV_20201230T001523_20201230T001548_035908_04349D_C91E",
        "S1B_IW_GRDH_1SDV_20200627T001449_20200627T001514_022212_02A27E_2A09",
        "S1B_IW_GRDH_1SDV_20200928T120105_20200928T120130_023575_02CCB0_F035",
    ],
    "0000001": [
        "S1B_IW_GRDH_1SDV_20201101T091054_20201101T091119_024069_02DC0F_F189",
        "S1B_IW_GRDH_1SDV_20210205T091050_20210205T091115_025469_0308CB_AA25",
        "S1B_IW_GRDH_1SDV_20210430T091051_20210430T091116_026694_03303D_69B6",
        "S1B_IW_GRDH_1SDV_20210804T091057_20210804T091122_028094_0359FE_6D9D",
    ],
}

# Scene directory names shared by the "s2c" and "s2a" trees.
_S2_SCENES = {
    "0000000": [
        "20200323T162931_20200323T163750_T15QXA",
        "20200621T162901_20200621T164746_T15QXA",
        "20200924T162929_20200924T164434_T15QXA",
        "20201228T163711_20201228T164519_T15QXA",
    ],
    "0000001": [
        "20201104T135121_20201104T135117_T21KXT",
        "20210123T135111_20210123T135113_T21KXT",
        "20210508T135109_20210508T135519_T21KXT",
        "20210811T135121_20210811T135115_T21KXT",
    ],
}

# Full directory tree to generate: top-level tree -> parent directory ->
# scene directory -> list of files. Every leaf of a tree is the same list object.
filenames: FILENAME_HIERARCHY = {
    "s1": {
        parent: {scene: s1 for scene in scenes}
        for parent, scenes in _S1_SCENES.items()
    },
    "s2c": {
        parent: {scene: s2c for scene in scenes}
        for parent, scenes in _S2_SCENES.items()
    },
    "s2a": {
        parent: {scene: s2a for scene in scenes}
        for parent, scenes in _S2_SCENES.items()
    },
}


def create_file(path: str) -> None:
    """Write a single-band GeoTIFF filled with random values to ``path``.

    The data type, raster dimensions and geotransform are chosen from the
    file-name suffix of ``path``.

    Args:
        path: destination of the GeoTIFF; its suffix (e.g. ``B5.tif``)
            selects the raster profile
    """
    float_suffixes = ("VH.tif", "VV.tif")
    sixth_size_suffixes = ("B1.tif", "B9.tif", "B10.tif")
    half_size_suffixes = (
        "B5.tif",
        "B6.tif",
        "B7.tif",
        "B8A.tif",
        "B11.tif",
        "B12.tif",
    )

    # VH/VV rasters are stored as floats, everything else as unsigned 16-bit.
    dtype = "float32" if path.endswith(float_suffixes) else "uint16"

    # The listed bands get a raster 1/6 or 1/2 the full edge length, each with
    # its own geotransform; all remaining files use the full SIZE x SIZE grid.
    if path.endswith(sixth_size_suffixes):
        side = SIZE // 6
        transform = Affine(
            0.0005532946262789551,
            0.0,
            -91.84592649682799,
            0.0,
            -0.0005278320358784425,
            18.588322889892943,
        )
    elif path.endswith(half_size_suffixes):
        side = SIZE // 2
        transform = Affine(
            0.00018443154209298504,
            0.0,
            -91.84574206528589,
            0.0,
            -0.0001759440119594808,
            18.588146945880982,
        )
    else:
        side = SIZE
        transform = Affine(
            9.221577104649252e-05,
            0.0,
            -91.84569595740037,
            0.0,
            -8.79720059797404e-05,
            18.588102959877993,
        )

    profile = {
        "driver": "GTiff",
        "dtype": dtype,
        "width": side,
        "height": side,
        "count": 1,
        "crs": CRS.from_epsg(4326),
        "transform": transform,
    }

    # The array is always SIZE x SIZE, even when the raster itself is smaller.
    # NOTE(review): this presumably relies on rasterio/GDAL fitting the buffer
    # to the raster on write -- confirm before changing the shape, since doing
    # so would also change the values drawn from the seeded RNG.
    pixels = np.random.randn(SIZE, SIZE).astype(dtype)

    with rasterio.open(path, "w", **profile) as dst:
        dst.write(pixels, 1)


def create_directory(directory: str, hierarchy: FILENAME_HIERARCHY) -> None:
    """Recursively create the directories and files described by ``hierarchy``.

    Args:
        directory: existing directory in which to create the entries
        hierarchy: either a dict mapping subdirectory names to their own
            hierarchies, or a list of file names to create in ``directory``
    """
    if not isinstance(hierarchy, dict):
        # Leaf: a list of file names, each written directly into ``directory``.
        for filename in hierarchy:
            create_file(os.path.join(directory, filename))
        return

    # Branch: make each subdirectory, then descend into it.
    for name, contents in hierarchy.items():
        subdirectory = os.path.join(directory, name)
        os.makedirs(subdirectory, exist_ok=True)
        create_directory(subdirectory, contents)


if __name__ == "__main__":
    # Generate the whole fake directory tree in the current working directory.
    create_directory(".", filenames)

    # Archive base name -> top-level directory packed into that archive.
    archives = {"s1": "s1", "s2_l1c": "s2c", "s2_l2a": "s2a"}
    for archive, directory in archives.items():
        # Pack the directory into <archive>.tar.gz
        shutil.make_archive(archive, "gztar", ".", directory)

        # Print the tarball's MD5 next to its name
        with open(f"{archive}.tar.gz", "rb") as tarball:
            checksum = hashlib.md5(tarball.read()).hexdigest()
        print(archive, checksum)
Binary file added tests/data/ssl4eo/s12/s1.tar.gz
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added tests/data/ssl4eo/s12/s2_l1c.tar.gz
Binary file not shown.
Binary file added tests/data/ssl4eo/s12/s2_l2a.tar.gz
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
74 changes: 74 additions & 0 deletions tests/datasets/test_ssl4eo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
import shutil
from pathlib import Path
from typing import cast

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from _pytest.monkeypatch import MonkeyPatch
from torch.utils.data import ConcatDataset

from torchgeo.datasets import SSL4EOS12


class TestSSL4EOS12:
    """Tests for the SSL4EOS12 dataset using the files in tests/data/ssl4eo/s12."""

    @pytest.fixture(params=zip(SSL4EOS12.metadata.keys(), [1, 1, 2]))
    def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> SSL4EOS12:
        # Override the expected MD5 of every split for the duration of the test
        # (presumably to match the small archives stored next to the tests).
        test_md5s = {
            "s1": "a716f353e4c2f0014f2e1f1ad848f82e",
            "s2c": "85eaf474af5642588a97dc5c991cfc15",
            "s2a": "df41a5d1ae6f840bc9a11ee254110369",
        }
        for name, md5 in test_md5s.items():
            monkeypatch.setitem(SSL4EOS12.metadata[name], "md5", md5)

        # Each fixture parameter is a (split, number of seasons) pair.
        split, seasons = request.param
        data_root = os.path.join("tests", "data", "ssl4eo", "s12")
        return SSL4EOS12(data_root, split, seasons, nn.Identity(), checksum=True)

    def test_getitem(self, dataset: SSL4EOS12) -> None:
        sample = dataset[0]
        assert isinstance(sample, dict)
        image = sample["image"]
        assert isinstance(image, torch.Tensor)
        # The first dimension holds one entry per band for every season.
        assert image.size(0) == dataset.seasons * len(dataset.bands)

    def test_len(self, dataset: SSL4EOS12) -> None:
        assert len(dataset) == 251079

    def test_add(self, dataset: SSL4EOS12) -> None:
        combined = dataset + dataset
        assert isinstance(combined, ConcatDataset)
        assert len(combined) == 2 * 251079

    def test_extract(self, tmp_path: Path) -> None:
        # Copy only the archives into an empty root, then build the dataset there.
        source_dir = os.path.join("tests", "data", "ssl4eo", "s12")
        for info in SSL4EOS12.metadata.values():
            archive = cast(str, info["filename"])
            shutil.copyfile(os.path.join(source_dir, archive), tmp_path / archive)
        SSL4EOS12(str(tmp_path))

    def test_invalid_split(self) -> None:
        with pytest.raises(AssertionError):
            SSL4EOS12(split="foo")

    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(RuntimeError, match="Dataset not found"):
            SSL4EOS12(str(tmp_path))

    def test_plot(self, dataset: SSL4EOS12) -> None:
        first = dataset[0]

        # Plot once with a figure title ...
        dataset.plot(first, suptitle="Test")
        plt.close()

        # ... and once with titles switched off.
        dataset.plot(first, show_titles=False)
        plt.close()
19 changes: 14 additions & 5 deletions tests/trainers/test_byol.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@
from torchvision.models import resnet18
from torchvision.models._api import WeightsEnum

from torchgeo.datamodules import ChesapeakeCVPRDataModule, SeasonalContrastS2DataModule
from torchgeo.datasets import SeasonalContrastS2
from torchgeo.datamodules import (
ChesapeakeCVPRDataModule,
SeasonalContrastS2DataModule,
SSL4EOS12DataModule,
)
from torchgeo.datasets import SSL4EOS12, SeasonalContrastS2
from torchgeo.models import get_model_weights, list_models
from torchgeo.trainers import BYOLTask
from torchgeo.trainers.byol import BYOL, SimCLRAugmentation
Expand Down Expand Up @@ -52,9 +56,11 @@ class TestBYOLTask:
@pytest.mark.parametrize(
"name,classname",
[
("chesapeake_cvpr_prior", ChesapeakeCVPRDataModule),
("seco_1", SeasonalContrastS2DataModule),
("seco_2", SeasonalContrastS2DataModule),
("chesapeake_cvpr_prior_byol", ChesapeakeCVPRDataModule),
("seco_byol_1", SeasonalContrastS2DataModule),
("seco_byol_2", SeasonalContrastS2DataModule),
("ssl4eo_s12_byol_1", SSL4EOS12DataModule),
("ssl4eo_s12_byol_2", SSL4EOS12DataModule),
],
)
def test_trainer(
Expand All @@ -71,6 +77,9 @@ def test_trainer(
if name.startswith("seco"):
monkeypatch.setattr(SeasonalContrastS2, "__len__", lambda self: 2)

if name.startswith("ssl4eo_s12"):
monkeypatch.setattr(SSL4EOS12, "__len__", lambda self: 2)

# Instantiate datamodule
datamodule_kwargs = conf_dict["datamodule"]
datamodule = classname(**datamodule_kwargs)
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .sen12ms import SEN12MSDataModule
from .so2sat import So2SatDataModule
from .spacenet import SpaceNet1DataModule
from .ssl4eo import SSL4EOS12DataModule
from .ucmerced import UCMercedDataModule
from .usavars import USAVarsDataModule
from .utils import MisconfigurationException
Expand Down Expand Up @@ -59,6 +60,7 @@
"SEN12MSDataModule",
"So2SatDataModule",
"SpaceNet1DataModule",
"SSL4EOS12DataModule",
"TropicalCycloneDataModule",
"UCMercedDataModule",
"USAVarsDataModule",
Expand Down
Loading

0 comments on commit 7c1d05a

Please sign in to comment.