Add/unet #67

Merged: 19 commits (Aug 20, 2024)
1 change: 1 addition & 0 deletions .github/workflows/test.yaml
@@ -28,6 +28,7 @@ jobs:
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements/required.txt -r requirements/test.txt
          mim install mmsegmentation
      - name: List pip dependencies
        run: pip list
      - name: Test with pytest
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -2,6 +2,10 @@
requires = [ "setuptools" ]
build-backend = 'setuptools.build_meta'

# Allows installation in editable mode via `pip install -e .`
[tool.setuptools]
py-modules = []

[project]
name = "terratorch"
version = "0.99.1"
8 changes: 7 additions & 1 deletion requirements/required.txt
@@ -10,4 +10,10 @@ lightly==1.4.25
h5py==3.10.0
geobench==1.0.0
mlflow==2.14.3
lightning==2.2.5
mmcv==2.1.0
# Extra dependencies required by mmseg
ftfy
regex
openmim
# mmsegmentation itself is installed with `mim install mmsegmentation` (see the CI workflow above)
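
Since mmsegmentation itself comes from `mim` rather than this requirements file, a quick import check can confirm the environment is complete. A minimal sketch (not part of the PR; run after `mim install mmsegmentation`):

# Sanity check: mmcv and mmseg should both be importable, along with
# the decode heads that the new factory exposes.
import mmcv
import mmseg
from mmseg.models.decode_heads import ASPPHead

print(mmcv.__version__, mmseg.__version__)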
2 changes: 2 additions & 0 deletions terratorch/models/__init__.py
@@ -5,13 +5,15 @@
from terratorch.models.scalemae_model_factory import ScaleMAEModelFactory
from terratorch.models.smp_model_factory import SMPModelFactory
from terratorch.models.timm_model_factory import TimmModelFactory
from terratorch.models.generic_unet_model_factory import GenericUnetModelFactory

__all__ = (
    "PrithviModelFactory",
    "ClayModelFactory",
    "SatMAEModelFactory",
    "ScaleMAEModelFactory",
    "SMPModelFactory",
    "GenericUnetModelFactory",
    "TimmModelFactory",
    "AuxiliaryHead",
    "AuxiliaryHeadWithDecoderWithoutInstantiatedHead",
117 changes: 117 additions & 0 deletions terratorch/models/generic_unet_model_factory.py
@@ -0,0 +1,117 @@
# Copyright contributors to the Terratorch project

"""
This is just an example of a possible structure to include SMP models
Right now it always returns a UNET, but could easily be extended to many of the models provided by SMP.
"""

import importlib

import torch
from torch import nn

from terratorch.models.model import Model, ModelFactory, ModelOutput, register_factory
from terratorch.tasks.segmentation_tasks import to_segmentation_prediction

def freeze_module(module: nn.Module):
    for param in module.parameters():
        param.requires_grad_(False)

@register_factory
class GenericUnetModelFactory(ModelFactory):
    def build_model(
        self,
        task: str = "segmentation",
        backbone: str | None = None,
        decoder: str | None = None,
        dilations: tuple[int, ...] = (1, 6, 12, 18),
        in_channels: int = 6,
        pretrained: str | bool | None = True,
        num_classes: int = 1,
        regression_relu: bool = False,
        **kwargs,
    ) -> Model:
        """Factory to create pixel-wise models from mmseg.

        Args:
            task (str): Must be "segmentation" or "regression".
            backbone (str | None): Name of a backbone class in mmseg.models.backbones.
            decoder (str | None): Name of a decode head class in mmseg.models.decode_heads.
            in_channels (int): Number of input channels.
            pretrained (str | bool | None): Which weights to use for the backbone. Defaults to True.
            num_classes (int): Number of classes.
            regression_relu (bool): Whether to apply a ReLU if task is regression. Defaults to False.

        Returns:
            Model: mmseg model wrapped in GenericUnetModelWrapper.
        """
        if task not in ["segmentation", "regression"]:
            msg = f"mmseg models can only perform pixel wise tasks, but got task {task}"
            raise Exception(msg)

        mmseg_decoders = importlib.import_module("mmseg.models.decode_heads")
        mmseg_encoders = importlib.import_module("mmseg.models.backbones")

        if backbone:
            backbone_kwargs = _extract_prefix_keys(kwargs, "backbone_")
            model = backbone
            model_kwargs = backbone_kwargs
            mmseg = mmseg_encoders
        elif decoder:
            decoder_kwargs = _extract_prefix_keys(kwargs, "decoder_")
            model = decoder
            model_kwargs = decoder_kwargs
            mmseg = mmseg_decoders
        else:
            msg = "It is necessary to define a backbone and/or a decoder."
            raise Exception(msg)

        model_class = getattr(mmseg, model)

        model = model_class(
            **model_kwargs,
        )

        return GenericUnetModelWrapper(
            model, relu=task == "regression" and regression_relu, squeeze_single_class=task == "regression"
        )

class GenericUnetModelWrapper(Model, nn.Module):
    def __init__(self, unet_model, relu=False, squeeze_single_class=False) -> None:
        super().__init__()
        self.unet_model = unet_model
        self.final_act = nn.ReLU() if relu else nn.Identity()
        self.squeeze_single_class = squeeze_single_class

    def forward(self, *args, **kwargs):
        # Assumes the input has shape (B, C, H, W).
        input_data = [args[0]]  # Adapts the input to become a list of time 'snapshots'
        args = (input_data,)

        unet_output = self.unet_model(*args, **kwargs)
        unet_output = self.final_act(unet_output)

        if unet_output.shape[1] == 1 and self.squeeze_single_class:
            unet_output = unet_output.squeeze(1)

        model_output = ModelOutput(unet_output)

        return model_output

    def freeze_encoder(self):
        raise NotImplementedError()

    def freeze_decoder(self):
        return freeze_module(self.unet_model)


def _extract_prefix_keys(d: dict, prefix: str) -> dict:
    extracted_dict = {}
    keys_to_del = []
    for k, v in d.items():
        if k.startswith(prefix):
            extracted_dict[k.split(prefix)[1]] = v
            keys_to_del.append(k)

    for k in keys_to_del:
        del d[k]

    return extracted_dict
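
A minimal usage sketch (values mirror the test config below): any kwarg prefixed with `decoder_` has the prefix stripped by `_extract_prefix_keys` before being forwarded to the mmseg class:

from terratorch.models.generic_unet_model_factory import GenericUnetModelFactory

factory = GenericUnetModelFactory()
model = factory.build_model(
    task="segmentation",
    decoder="ASPPHead",                # resolved via getattr on mmseg.models.decode_heads
    decoder_dilations=(1, 6, 12, 18),  # forwarded to ASPPHead as `dilations`
    decoder_in_channels=6,
    decoder_channels=256,
    decoder_num_classes=2,
)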
148 changes: 148 additions & 0 deletions tests/manufactured-finetune_aspphead.yaml
@@ -0,0 +1,148 @@
# lightning.pytorch==2.1.1
seed_everything: 42
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  # precision: 16-mixed
  logger:
    class_path: TensorBoardLogger
    init_args:
      save_dir: tests/
      name: all_ecos_random
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        patience: 100
  max_epochs: 5
  check_val_every_n_epoch: 1
  log_every_n_steps: 20
  enable_checkpointing: true
  default_root_dir: tests/
data:
  class_path: GenericNonGeoSegmentationDataModule
  init_args:
    batch_size: 2
    num_workers: 4
    train_transform:
      - class_path: albumentations.HorizontalFlip
        init_args:
          p: 0.5
      - class_path: albumentations.Rotate
        init_args:
          limit: 30
          border_mode: 0 # cv2.BORDER_CONSTANT
          value: 0
          # mask_value: 1
          p: 0.5
      - class_path: ToTensorV2
    dataset_bands:
      - COASTAL_AEROSOL
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - CIRRUS
      - THERMAL_INFRARED_1
      - THERMAL_INFRARED_2
    output_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    rgb_indices:
      - 2
      - 1
      - 0
    train_data_root: tests/
    train_label_data_root: tests/
    val_data_root: tests/
    val_label_data_root: tests/
    test_data_root: tests/
    test_label_data_root: tests/
    img_grep: "segmentation*input*.tif"
    label_grep: "segmentation*label*.tif"
    means:
      - 547.36707
      - 898.5121
      - 1020.9082
      - 2665.5352
      - 2340.584
      - 1610.1407
    stds:
      - 411.4701
      - 558.54065
      - 815.94025
      - 812.4403
      - 1113.7145
      - 1067.641
    no_label_replace: -1
    no_data_replace: 0
    num_classes: 2
model:
  class_path: terratorch.tasks.SemanticSegmentationTask
  init_args:
    model_args:
      decoder: "ASPPHead"
      decoder_dilations: [1, 6, 12, 18]
      decoder_channels: 256
      decoder_in_channels: 6
      decoder_num_classes: 2
      in_channels: 6
      bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
      # num_frames: 1
      head_dropout: 0.5708022831486758
      head_final_act: torch.nn.ReLU
      head_learned_upscale_layers: 2
      num_classes: 2
    loss: ce
    # aux_heads:
    #   - name: aux_head
    #     decoder: IdentityDecoder
    #     decoder_args:
    #       decoder_out_index: 2
    #       head_dropout: 0.5
    #       head_channel_list:
    #         - 64
    #       head_final_act: torch.nn.ReLU
    # aux_loss:
    #   aux_head: 0.4
    ignore_index: -1
    # freeze_encoder: false
    # freeze_decoder: false
    model_factory: GenericUnetModelFactory

    # uncomment this block for tiled inference
    # tiled_inference_parameters:
    #   h_crop: 224
    #   h_stride: 192
    #   w_crop: 224
    #   w_stride: 192
    #   average_patches: true
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 0.00013524680528283027
    weight_decay: 0.047782217873995426
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss

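To make the wiring above concrete, a forward-pass sketch (illustrative shapes only, assuming `model` was built by the factory as in the earlier example and that terratorch's `ModelOutput` exposes the logits as `.output`):

import torch

x = torch.randn(2, 6, 224, 224)  # (B, C, H, W): batch of 2 with the 6 output bands
out = model(x)                   # the wrapper turns x into [x] before calling the mmseg head
logits = out.output              # per-pixel class logits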