Skip to content

Commit

Permalink
Merge pull request #67 from IBM/add/unet
Browse files Browse the repository at this point in the history
Add/unet
  • Loading branch information
Joao-L-S-Almeida authored Aug 20, 2024
2 parents a09f8e9 + 9f7eca0 commit cbc98f1
Show file tree
Hide file tree
Showing 16 changed files with 451 additions and 1 deletion.
1 change: 1 addition & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r requirements/required.txt -r requirements/test.txt
mim install mmsegmentation
- name: List pip dependencies
run: pip list
- name: Test with pytest
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
requires = [ "setuptools" ]
build-backend = 'setuptools.build_meta'

# It allows installation via `pip install -e`
[tool.setuptools]
py-modules = []

[project]
name = "terratorch"
version = "0.99.1"
Expand Down
8 changes: 7 additions & 1 deletion requirements/required.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,10 @@ lightly==1.4.25
h5py==3.10.0
geobench==1.0.0
mlflow==2.14.3
lightning==2.2.5
lightning==2.2.5
mmcv==2.1.0
# Extra dependencies required by mmseg
ftfy
regex
openmim
#mim mmsegmentation
2 changes: 2 additions & 0 deletions terratorch/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
from terratorch.models.scalemae_model_factory import ScaleMAEModelFactory
from terratorch.models.smp_model_factory import SMPModelFactory
from terratorch.models.timm_model_factory import TimmModelFactory
from terratorch.models.generic_unet_model_factory import GenericUnetModelFactory

__all__ = (
"PrithviModelFactory",
"ClayModelFactory",
"SatMAEModelFactory",
"ScaleMAEModelFactory",
"SMPModelFactory",
"GenericUnetModelFactory",
"TimmModelFactory",
"AuxiliaryHead",
"AuxiliaryHeadWithDecoderWithoutInstantiatedHead",
Expand Down
117 changes: 117 additions & 0 deletions terratorch/models/generic_unet_model_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright contributors to the Terratorch project

"""
This is just an example of a possible structure to include SMP models
Right now it always returns a UNET, but could easily be extended to many of the models provided by SMP.
"""

from torch import nn
import torch
from terratorch.models.model import Model, ModelFactory, ModelOutput, register_factory
from terratorch.tasks.segmentation_tasks import to_segmentation_prediction

import importlib

def freeze_module(module: nn.Module):
for param in module.parameters():
param.requires_grad_(False)

@register_factory
class GenericUnetModelFactory(ModelFactory):
def build_model(
self,
task: str = "segmentation",
backbone: str = None,
decoder: str = None,
dilations: tuple[int] = (1, 6, 12, 18),
in_channels: int = 6,
pretrained: str | bool | None = True,
num_classes: int = 1,
regression_relu: bool = False,
**kwargs,
) -> Model:
"""Factory to create model based on SMP.
Args:
task (str): Must be "segmentation".
model (str): Decoder architecture. Currently only supports "unet".
in_channels (int): Number of input channels.
pretrained(str | bool): Which weights to use for the backbone. If true, will use "imagenet". If false or None, random weights. Defaults to True.
num_classes (int): Number of classes.
regression_relu (bool). Whether to apply a ReLU if task is regression. Defaults to False.
Returns:
Model: SMP model wrapped in SMPModelWrapper.
"""
if task not in ["segmentation", "regression"]:
msg = f"SMP models can only perform pixel wise tasks, but got task {task}"
raise Exception(msg)

mmseg_decoders = importlib.import_module("mmseg.models.decode_heads")
mmseg_encoders = importlib.import_module("mmseg.models.backbones")

if backbone:
backbone_kwargs = _extract_prefix_keys(kwargs, "backbone_")
model = backbone
model_kwargs = backbone_kwargs
mmseg = mmseg_encoders
elif decoder:
decoder_kwargs = _extract_prefix_keys(kwargs, "decoder_")
model = decoder
model_kwargs = decoder_kwargs
mmseg = mmseg_decoders
else:
print("It is necessary to define a backbone and/or a decoder.")

model_class = getattr(mmseg, model)

model = model_class(
**model_kwargs,
)

return GenericUnetModelWrapper(
model, relu=task == "regression" and regression_relu, squeeze_single_class=task == "regression"
)

class GenericUnetModelWrapper(Model, nn.Module):
def __init__(self, unet_model, relu=False, squeeze_single_class=False) -> None:
super().__init__()
self.unet_model = unet_model
self.final_act = nn.ReLU() if relu else nn.Identity()
self.squeeze_single_class = squeeze_single_class

def forward(self, *args, **kwargs):

# It supposes the input has dimension (B, C, H, W)
input_data = [args[0]] # It adapts the input to became a list of time 'snapshots'
args = (input_data,)

unet_output = self.unet_model(*args, **kwargs)
unet_output = self.final_act(unet_output)

if unet_output.shape[1] == 1 and self.squeeze_single_class:
unet_output = unet_output.squeeze(1)

model_output = ModelOutput(unet_output)

return model_output

def freeze_encoder(self):
raise NotImplementedError()

def freeze_decoder(self):
raise freeze_module(self.unet_model)


def _extract_prefix_keys(d: dict, prefix: str) -> dict:
extracted_dict = {}
keys_to_del = []
for k, v in d.items():
if k.startswith(prefix):
extracted_dict[k.split(prefix)[1]] = v
keys_to_del.append(k)

for k in keys_to_del:
del d[k]

return extracted_dict
148 changes: 148 additions & 0 deletions tests/manufactured-finetune_aspphead.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# lightning.pytorch==2.1.1
seed_everything: 42
trainer:
accelerator: auto
strategy: auto
devices: auto
num_nodes: 1
#precision: 16-mixed
# precision: 16-mixed
logger:
class_path: TensorBoardLogger
init_args:
save_dir: tests/
name: all_ecos_random
callbacks:
- class_path: RichProgressBar
- class_path: LearningRateMonitor
init_args:
logging_interval: epoch
- class_path: EarlyStopping
init_args:
monitor: val/loss
patience: 100
max_epochs: 5
check_val_every_n_epoch: 1
log_every_n_steps: 20
enable_checkpointing: true
default_root_dir: tests/
data:
class_path: GenericNonGeoSegmentationDataModule
init_args:
batch_size: 2
num_workers: 4
train_transform:
- class_path: albumentations.HorizontalFlip
init_args:
p: 0.5
- class_path: albumentations.Rotate
init_args:
limit: 30
border_mode: 0 # cv2.BORDER_CONSTANT
value: 0
# mask_value: 1
p: 0.5
- class_path: ToTensorV2
dataset_bands:
- COASTAL_AEROSOL
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
- CIRRUS
- THEMRAL_INFRARED_1
- THEMRAL_INFRARED_2
output_bands:
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
rgb_indices:
- 2
- 1
- 0
train_data_root: tests/
train_label_data_root: tests/
val_data_root: tests/
val_label_data_root: tests/
test_data_root: tests/
test_label_data_root: tests/
img_grep: "segmentation*input*.tif"
label_grep: "segmentation*label*.tif"
means:
- 547.36707
- 898.5121
- 1020.9082
- 2665.5352
- 2340.584
- 1610.1407
stds:
- 411.4701
- 558.54065
- 815.94025
- 812.4403
- 1113.7145
- 1067.641
no_label_replace: -1
no_data_replace: 0
num_classes: 2
model:
class_path: terratorch.tasks.SemanticSegmentationTask
init_args:
model_args:
decoder: "ASPPHead"
decoder_dilations: [1, 6, 12, 18]
decoder_channels: 256
decoder_in_channels: 6
decoder_num_classes: 2
in_channels: 6
bands:
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
#num_frames: 1
head_dropout: 0.5708022831486758
head_final_act: torch.nn.ReLU
head_learned_upscale_layers: 2
num_classes: 2
loss: ce
#aux_heads:
# - name: aux_head
# decoder: IdentityDecoder
# decoder_args:
# decoder_out_index: 2
# head_dropout: 0,5
# head_channel_list:
# - 64
# head_final_act: torch.nn.ReLU
#aux_loss:
# aux_head: 0.4
ignore_index: -1
#freeze_encoder: false #true
#freeze_decoder: false
model_factory: GenericUnetModelFactory

# uncomment this block for tiled inference
# tiled_inference_parameters:
# h_crop: 224
# h_stride: 192
# w_crop: 224
# w_stride: 192
# average_patches: true
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 0.00013524680528283027
weight_decay: 0.047782217873995426
lr_scheduler:
class_path: ReduceLROnPlateau
init_args:
monitor: val/loss

Loading

0 comments on commit cbc98f1

Please sign in to comment.