Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AgriFieldNet India Challenge dataset #1459

Merged
merged 52 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
7e23f5b
add agrifieldnet dataset
yichiac Jul 3, 2023
19b379a
modified len check
yichiac Jul 3, 2023
bd6aa4e
improve _download
yichiac Jul 3, 2023
aba7d9f
Merge branch 'main' into datasets/agrifieldnet
yichiac Jul 3, 2023
f7d0c40
remove augmentation and wrong datamodule names
yichiac Jul 5, 2023
7e715b7
Merge branch 'microsoft:main' into datasets/agrifieldnet
yichiac Jul 7, 2023
31f7645
update data.py and dataset
yichiac Jul 7, 2023
c17bc94
Merge branch 'main' into datasets/agrifieldnet
yichiac Jul 8, 2023
58e9c92
update splits
yichiac Jul 9, 2023
d28238a
remove patch_size change
yichiac Jul 9, 2023
b3f769d
fix style issues
yichiac Jul 9, 2023
304fd1d
add yaml and modify/test for training
yichiac Jul 10, 2023
7aa7287
Merge branch 'microsoft:main' into datasets/agrifieldnet
yichiac Jul 11, 2023
812ae2f
fix data path and add trainer
yichiac Jul 12, 2023
ee764e1
export prediction
yichiac Jul 13, 2023
511a069
Merge branch 'microsoft:main' into datasets/agrifieldnet
yichiac Jul 13, 2023
337e34b
fix integrity check and len
yichiac Jul 14, 2023
f63fa0f
extract predction
yichiac Jul 16, 2023
a73e591
adding create submission file function
yichiac Jul 24, 2023
2eb69c9
adding create submission file function
yichiac Jul 24, 2023
5b1ff20
Merge branch 'main' into datasets/agrifieldnet
yichiac Jul 24, 2023
da778f8
hyperparam tuning exp
yichiac Aug 3, 2023
6bb7e71
backup experiments
yichiac Aug 25, 2023
b82a0d1
Merge branch 'microsoft:main' into datasets/agrifieldnet
yichiac Aug 25, 2023
387d2fc
remove redundant files
yichiac Aug 25, 2023
29e7b6d
reverse segmentation.py
yichiac Aug 25, 2023
bd00344
resolve minor issues
yichiac Aug 25, 2023
0bf7805
modify yaml and add exp files
yichiac Aug 27, 2023
faab754
update data.py
yichiac Oct 23, 2023
6862277
remove outdated train.py
yichiac Jan 28, 2024
ee6cb66
Merge branch 'main' into datasets/agrifieldnet
yichiac Jan 28, 2024
7123563
update dataset, test, and new data
yichiac Feb 2, 2024
f62271e
Merge branch 'microsoft:main' into datasets/agrifieldnet
yichiac Feb 2, 2024
7c695fa
fix style
yichiac Feb 2, 2024
09cad2c
fix doc api
yichiac Feb 2, 2024
1b5ba21
remove datamodule
yichiac Feb 2, 2024
849fae9
fix geo_datasets.csv
yichiac Feb 2, 2024
0844abe
fix codecov
yichiac Feb 2, 2024
277a45d
Merge branch 'microsoft:main' into datasets/agrifieldnet
yichiac Feb 3, 2024
e12e79c
fix read tif issue
yichiac Feb 3, 2024
2434d6c
Update torchgeo/datasets/agrifieldnet.py
yichiac Feb 5, 2024
385e893
Merge branch 'microsoft:main' into datasets/agrifieldnet
yichiac Feb 6, 2024
1edfa12
fix init
yichiac Feb 6, 2024
a27223e
Merge branch 'microsoft:main' into datasets/agrifieldnet
yichiac Feb 8, 2024
b98d15e
add ordinal_cmap to pred and remove comments
yichiac Feb 8, 2024
18352ee
remove suffix
yichiac Feb 9, 2024
c2ae3bc
remove download entirely
yichiac Feb 9, 2024
ff6f1ff
style
yichiac Feb 9, 2024
230089a
Merge branch 'microsoft:main' into datasets/agrifieldnet
yichiac Feb 9, 2024
72e6ea8
Update agrifieldnet.py
yichiac Feb 12, 2024
04d2391
Update agrifieldnet.py
yichiac Feb 12, 2024
55e811c
remove url and if statement
yichiac Feb 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ Aboveground Woody Biomass

.. autoclass:: AbovegroundLiveWoodyBiomassDensity

AgriFieldNet
^^^^^^^^^^^^

.. autoclass:: AgriFieldNet

Airphen
^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/api/geo_datasets.csv
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
Dataset,Type,Source,License,Size (px),Resolution (m)
`Aboveground Woody Biomass`_,Masks,"Landsat, LiDAR","CC-BY-4.0","40,000x40,000",30
`AgriFieldNet`_,"Imagery, Masks",Sentinel-2,"CC-BY-4.0","256x256",10
`Airphen`_,Imagery,Airphen,-,"1,280x960",0.047--0.09
`Aster Global DEM`_,Masks,Aster,"public domain","3,601x3,601",30
`Canadian Building Footprints`_,Geometries,Bing Imagery,"ODbL-1.0",-,-
Expand Down
108 changes: 108 additions & 0 deletions tests/data/agrifieldnet/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import numpy as np
import rasterio
from rasterio.crs import CRS
from rasterio.transform import Affine


def generate_test_data(paths: str) -> str:
"""Create test data archive for AgriFieldNet dataset.

Args:
paths: path to store test data
n_samples: number of samples.

Returns:
md5 hash of created archive
"""
dtype = np.uint8
dtype_max = np.iinfo(dtype).max

SIZE = 32

np.random.seed(0)

bands = (
"B01",
"B02",
"B03",
"B04",
"B05",
"B06",
"B07",
"B08",
"B8A",
"B09",
"B11",
"B12",
)

profile = {
"dtype": dtype,
"width": SIZE,
"height": SIZE,
"count": 1,
"crs": CRS.from_epsg(32644),
"transform": Affine(10.0, 0.0, 535840.0, 0.0, -10.0, 3079680.0),
}

source_dir = os.path.join(paths, "source")
train_mask_dir = os.path.join(paths, "train_labels")
test_field_dir = os.path.join(paths, "test_labels")

os.makedirs(source_dir, exist_ok=True)
os.makedirs(train_mask_dir, exist_ok=True)
os.makedirs(test_field_dir, exist_ok=True)

source_unique_folder_ids = ["32407", "8641e", "a419f", "eac11", "ff450"]
train_folder_ids = source_unique_folder_ids[0:5]
test_folder_ids = source_unique_folder_ids[3:5]

for id in source_unique_folder_ids:
directory = os.path.join(
source_dir, "ref_agrifieldnet_competition_v1_source_" + id
)
os.makedirs(directory, exist_ok=True)

for band in bands:
train_arr = np.random.randint(dtype_max, size=(SIZE, SIZE), dtype=dtype)
path = os.path.join(
directory, f"ref_agrifieldnet_competition_v1_source_{id}_{band}_10m.tif"
)
with rasterio.open(path, "w", **profile) as src:
src.write(train_arr, 1)

for id in train_folder_ids:
train_mask_arr = np.random.randint(size=(SIZE, SIZE), low=0, high=6)
path = os.path.join(
train_mask_dir, f"ref_agrifieldnet_competition_v1_labels_train_{id}.tif"
)
with rasterio.open(path, "w", **profile) as src:
src.write(train_mask_arr, 1)

train_field_arr = np.random.randint(20, size=(SIZE, SIZE), dtype=np.uint16)
path = os.path.join(
train_mask_dir,
f"ref_agrifieldnet_competition_v1_labels_train_{id}_field_ids.tif",
)
with rasterio.open(path, "w", **profile) as src:
src.write(train_field_arr, 1)

for id in test_folder_ids:
test_field_arr = np.random.randint(10, 30, size=(SIZE, SIZE), dtype=np.uint16)
path = os.path.join(
test_field_dir,
f"ref_agrifieldnet_competition_v1_labels_test_{id}_field_ids.tif",
)
with rasterio.open(path, "w", **profile) as src:
src.write(test_field_arr, 1)


if __name__ == "__main__":
generate_test_data(os.getcwd())
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
77 changes: 77 additions & 0 deletions tests/datasets/test_agrifieldnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
from pathlib import Path

import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from rasterio.crs import CRS

from torchgeo.datasets import (
AgriFieldNet,
BoundingBox,
DatasetNotFoundError,
IntersectionDataset,
RGBBandsMissingError,
UnionDataset,
)


class TestAgriFieldNet:
@pytest.fixture
def dataset(self) -> AgriFieldNet:
path = os.path.join("tests", "data", "agrifieldnet")
transforms = nn.Identity()
return AgriFieldNet(paths=path, transforms=transforms)

def test_getitem(self, dataset: AgriFieldNet) -> None:
x = dataset[dataset.bounds]
assert isinstance(x, dict)
assert isinstance(x["crs"], CRS)
assert isinstance(x["image"], torch.Tensor)
assert isinstance(x["mask"], torch.Tensor)

def test_and(self, dataset: AgriFieldNet) -> None:
ds = dataset & dataset
assert isinstance(ds, IntersectionDataset)

def test_or(self, dataset: AgriFieldNet) -> None:
ds = dataset | dataset
assert isinstance(ds, UnionDataset)

def test_already_downloaded(self, dataset: AgriFieldNet) -> None:
AgriFieldNet(paths=dataset.paths)

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
AgriFieldNet(str(tmp_path))

def test_plot(self, dataset: AgriFieldNet) -> None:
x = dataset[dataset.bounds]
dataset.plot(x, suptitle="Test")
plt.close()

def test_plot_prediction(self, dataset: AgriFieldNet) -> None:
x = dataset[dataset.bounds]
x["prediction"] = x["mask"].clone()
dataset.plot(x, suptitle="Prediction")
plt.close()

def test_invalid_query(self, dataset: AgriFieldNet) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
with pytest.raises(
IndexError, match="query: .* not found in index with bounds:"
):
dataset[query]

def test_rgb_bands_absent_plot(self, dataset: AgriFieldNet) -> None:
with pytest.raises(
RGBBandsMissingError, match="Dataset does not contain some of the RGB bands"
):
ds = AgriFieldNet(dataset.paths, bands=["B01", "B02", "B05"])
x = ds[ds.bounds]
ds.plot(x, suptitle="Test")
plt.close()
2 changes: 2 additions & 0 deletions torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .advance import ADVANCE
from .agb_live_woody_density import AbovegroundLiveWoodyBiomassDensity
from .agrifieldnet import AgriFieldNet
from .airphen import Airphen
from .astergdem import AsterGDEM
from .benin_cashews import BeninSmallHolderCashews
Expand Down Expand Up @@ -139,6 +140,7 @@
__all__ = (
# GeoDataset
"AbovegroundLiveWoodyBiomassDensity",
"AgriFieldNet",
"Airphen",
"AsterGDEM",
"CanadianBuildingFootprints",
Expand Down
Loading
Loading