Skip to content

Commit

Permalink
Merge branch 'main' into update-metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
robmarkcole authored Jul 2, 2024
2 parents 8ce8c30 + 0448d6d commit e4ed9fd
Show file tree
Hide file tree
Showing 12 changed files with 52 additions and 61 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ torchgeo fit --config config.yaml
# Validate-only
torchgeo validate --config config.yaml
# Calculate and report test accuracy
torchgeo test --config config.yaml ckpt_path=...
torchgeo test --config config.yaml --ckpt_path=...
```

It can also be imported and used in a Python script if you need to extend it to add new features:
Expand Down
16 changes: 0 additions & 16 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,22 +144,6 @@
:class: colabbadge
:alt: Open in Colab
:target: {{ host }}/github/{{ repo }}/blob/{{ branch }}/{{ urlpath }}
{% set host = "https://pccompute.westeurope.cloudapp.azure.com" %}
{% set host = host ~ "/compute/hub/user-redirect/git-pull" %}
{% set repo = "https%3A%2F%2Fgithub.com%2Fmicrosoft%2Ftorchgeo" %}
{% set urlpath = "tree%2Ftorchgeo%2Fdocs%2F" %}
{% set urlpath = urlpath ~ env.docname | replace("/", "%2F") ~ ".ipynb" %}
{% if "dev" in env.config.release %}
{% set branch = "main" %}
{% else %}
{% set branch = "releases%2Fv" ~ env.config.version %}
{% endif %}
.. image:: https://img.shields.io/badge/-Open%20on%20Planetary%20Computer-blue
:class: colabbadge
:alt: Open on Planetary Computer
:target: {{ host }}?repo={{ repo }}&urlpath={{ urlpath }}&branch={{ branch }}
"""

# Disables requirejs in nbsphinx to enable compatibility with the pytorch_sphinx_theme
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ dependencies = [
"einops>=0.3",
# fiona 1.8.21+ required for Python 3.10 wheels
"fiona>=1.8.21",
# kornia 0.7.2+ required for dict support in AugmentationSequential
"kornia>=0.7.2",
# kornia 0.7.3+ required for instance segmentation support in AugmentationSequential
"kornia>=0.7.3",
# lightly 1.4.4+ required for MoCo v3 support
# lightly 1.4.26 is incompatible with the version of timm required by smp
# https://github.com/microsoft/torchgeo/issues/1824
Expand Down
2 changes: 1 addition & 1 deletion requirements/datasets.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ pyvista==0.43.10
radiant-mlhub==0.4.1
rarfile==4.2
scikit-image==0.24.0
scipy==1.13.1
scipy==1.14.0
zipfile-deflate64==0.2.0
2 changes: 1 addition & 1 deletion requirements/min-reqs.old
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ setuptools==61.0.0
# install
einops==0.3.0
fiona==1.8.21
kornia==0.7.2
kornia==0.7.3
lightly==1.4.4
lightning[pytorch-extra]==2.0.0
matplotlib==3.5.0
Expand Down
12 changes: 9 additions & 3 deletions requirements/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion requirements/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
"name": "torchgeo",
"private": "true",
"dependencies": {
"prettier": ">=3.0.0"
"prettier": ">=3.3.2"
}
}
6 changes: 3 additions & 3 deletions requirements/required.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# setup
setuptools==70.1.0
setuptools==70.2.0

# install
einops==0.8.0
fiona==1.9.6
kornia==0.7.2
lightly==1.5.7
kornia==0.7.3
lightly==1.5.8
lightning[pytorch-extra]==2.2.5
matplotlib==3.9.0
numpy==1.26.4
Expand Down
4 changes: 2 additions & 2 deletions requirements/style.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# style
mypy==1.10.0
ruff==0.4.10
mypy==1.10.1
ruff==0.5.0
15 changes: 9 additions & 6 deletions tests/transforms/test_color.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import kornia.augmentation as K
import pytest
import torch
from torch import Tensor

from torchgeo.transforms import AugmentationSequential, RandomGrayscale
from torchgeo.transforms import RandomGrayscale


@pytest.fixture
Expand Down Expand Up @@ -33,12 +34,15 @@ def batch() -> dict[str, Tensor]:
],
)
def test_random_grayscale_sample(weights: Tensor, sample: dict[str, Tensor]) -> None:
aug = AugmentationSequential(RandomGrayscale(weights, p=1), data_keys=['image'])
aug = K.AugmentationSequential(
RandomGrayscale(weights, p=1), keepdim=True, data_keys=None
)
# https://github.com/kornia/kornia/issues/2848
aug.keepdim = True
output = aug(sample)
assert output['image'].shape == sample['image'].shape
assert output['image'].sum() == sample['image'].sum()
for i in range(1, 3):
assert torch.allclose(output['image'][0, 0], output['image'][0, i])
assert torch.allclose(output['image'][0], output['image'][i])


@pytest.mark.parametrize(
Expand All @@ -50,9 +54,8 @@ def test_random_grayscale_sample(weights: Tensor, sample: dict[str, Tensor]) ->
],
)
def test_random_grayscale_batch(weights: Tensor, batch: dict[str, Tensor]) -> None:
aug = AugmentationSequential(RandomGrayscale(weights, p=1), data_keys=['image'])
aug = K.AugmentationSequential(RandomGrayscale(weights, p=1), data_keys=None)
output = aug(batch)
assert output['image'].shape == batch['image'].shape
assert output['image'].sum() == batch['image'].sum()
for i in range(1, 3):
assert torch.allclose(output['image'][0, 0], output['image'][0, i])
20 changes: 9 additions & 11 deletions tests/transforms/test_indices.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import kornia.augmentation as K
import pytest
import torch
from torch import Tensor
Expand All @@ -20,7 +21,6 @@
AppendRBNDVI,
AppendSWI,
AppendTriBandNormalizedDifferenceIndex,
AugmentationSequential,
)


Expand All @@ -42,29 +42,27 @@ def batch() -> dict[str, Tensor]:

def test_append_index_sample(sample: dict[str, Tensor]) -> None:
c, h, w = sample['image'].shape
aug = AugmentationSequential(
AppendNormalizedDifferenceIndex(index_a=0, index_b=1),
data_keys=['image', 'mask'],
aug = K.AugmentationSequential(
AppendNormalizedDifferenceIndex(index_a=0, index_b=1), data_keys=None
)
output = aug(sample)
assert output['image'].shape == (1, c + 1, h, w)


def test_append_index_batch(batch: dict[str, Tensor]) -> None:
b, c, h, w = batch['image'].shape
aug = AugmentationSequential(
AppendNormalizedDifferenceIndex(index_a=0, index_b=1),
data_keys=['image', 'mask'],
aug = K.AugmentationSequential(
AppendNormalizedDifferenceIndex(index_a=0, index_b=1), data_keys=None
)
output = aug(batch)
assert output['image'].shape == (b, c + 1, h, w)


def test_append_triband_index_batch(batch: dict[str, Tensor]) -> None:
b, c, h, w = batch['image'].shape
aug = AugmentationSequential(
aug = K.AugmentationSequential(
AppendTriBandNormalizedDifferenceIndex(index_a=0, index_b=1, index_c=2),
data_keys=['image', 'mask'],
data_keys=None,
)
output = aug(batch)
assert output['image'].shape == (b, c + 1, h, w)
Expand All @@ -88,7 +86,7 @@ def test_append_normalized_difference_indices(
sample: dict[str, Tensor], index: AppendNormalizedDifferenceIndex
) -> None:
c, h, w = sample['image'].shape
aug = AugmentationSequential(index(0, 1), data_keys=['image', 'mask'])
aug = K.AugmentationSequential(index(0, 1), data_keys=None)
output = aug(sample)
assert output['image'].shape == (1, c + 1, h, w)

Expand All @@ -98,6 +96,6 @@ def test_append_tri_band_normalized_difference_indices(
sample: dict[str, Tensor], index: AppendTriBandNormalizedDifferenceIndex
) -> None:
c, h, w = sample['image'].shape
aug = AugmentationSequential(index(0, 1, 2), data_keys=['image', 'mask'])
aug = K.AugmentationSequential(index(0, 1, 2), data_keys=None)
output = aug(sample)
assert output['image'].shape == (1, c + 1, h, w)
28 changes: 14 additions & 14 deletions torchgeo/datasets/agrifieldnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,20 @@ class AgriFieldNet(RasterDataset):
Dataset classes:
0 - No-Data
1 - Wheat
2 - Mustard
3 - Lentil
4 - No Crop/Fallow
5 - Green pea
6 - Sugarcane
8 - Garlic
9 - Maize
13 - Gram
14 - Coriander
15 - Potato
16 - Berseem
36 - Rice
* 0. No-Data
* 1. Wheat
* 2. Mustard
* 3. Lentil
* 4. No Crop/Fallow
* 5. Green pea
* 6. Sugarcane
* 8. Garlic
* 9. Maize
* 13. Gram
* 14. Coriander
* 15. Potato
* 16. Berseem
* 36. Rice
If you use this dataset in your research, please cite the following dataset:
Expand Down

0 comments on commit e4ed9fd

Please sign in to comment.