Merge pull request #67 from IBM/add/unet
Add/unet
Showing 16 changed files with 451 additions and 1 deletion.
@@ -0,0 +1,117 @@
# Copyright contributors to the Terratorch project

"""
This is an example of a possible structure for including mmseg (MMSegmentation) models.
Right now it instantiates a single backbone or decode head by name, but it could easily be
extended to cover many more of the models provided by mmseg.
"""

import importlib

import torch
from torch import nn

from terratorch.models.model import Model, ModelFactory, ModelOutput, register_factory
from terratorch.tasks.segmentation_tasks import to_segmentation_prediction

def freeze_module(module: nn.Module):
    for param in module.parameters():
        param.requires_grad_(False)

@register_factory
class GenericUnetModelFactory(ModelFactory):
    def build_model(
        self,
        task: str = "segmentation",
        backbone: str | None = None,
        decoder: str | None = None,
        dilations: tuple[int, ...] = (1, 6, 12, 18),
        in_channels: int = 6,
        pretrained: str | bool | None = True,
        num_classes: int = 1,
        regression_relu: bool = False,
        **kwargs,
    ) -> Model:
        """Factory to create a model from mmseg components.

        Args:
            task (str): Must be "segmentation" or "regression".
            backbone (str): Name of a class in mmseg.models.backbones to instantiate.
            decoder (str): Name of a class in mmseg.models.decode_heads to instantiate (e.g. "ASPPHead").
                Only used when no backbone is given.
            in_channels (int): Number of input channels.
            pretrained (str | bool): Which weights to use for the backbone. If true, will use "imagenet".
                If false or None, random weights. Defaults to True.
            num_classes (int): Number of classes.
            regression_relu (bool): Whether to apply a ReLU if task is regression. Defaults to False.
            **kwargs: Arguments prefixed with "backbone_" or "decoder_" are forwarded to the
                corresponding mmseg constructor.

        Returns:
            Model: mmseg model wrapped in a GenericUnetModelWrapper.
        """
        if task not in ["segmentation", "regression"]:
            msg = f"This factory can only perform pixel-wise tasks, but got task {task}"
            raise ValueError(msg)

        # mmseg is imported lazily, at build time.
        mmseg_decoders = importlib.import_module("mmseg.models.decode_heads")
        mmseg_encoders = importlib.import_module("mmseg.models.backbones")

        # The backbone takes precedence when both a backbone and a decoder are given.
        if backbone:
            backbone_kwargs = _extract_prefix_keys(kwargs, "backbone_")
            model = backbone
            model_kwargs = backbone_kwargs
            mmseg = mmseg_encoders
        elif decoder:
            decoder_kwargs = _extract_prefix_keys(kwargs, "decoder_")
            model = decoder
            model_kwargs = decoder_kwargs
            mmseg = mmseg_decoders
        else:
            msg = "It is necessary to define a backbone and/or a decoder."
            raise ValueError(msg)

        model_class = getattr(mmseg, model)

        model = model_class(
            **model_kwargs,
        )

        return GenericUnetModelWrapper(
            model, relu=task == "regression" and regression_relu, squeeze_single_class=task == "regression"
        )

class GenericUnetModelWrapper(Model, nn.Module):
    def __init__(self, unet_model, relu=False, squeeze_single_class=False) -> None:
        super().__init__()
        self.unet_model = unet_model
        self.final_act = nn.ReLU() if relu else nn.Identity()
        self.squeeze_single_class = squeeze_single_class

    def forward(self, *args, **kwargs):
        # Assumes the input has shape (B, C, H, W).
        # Wrap the input in a list so it is treated as a list of time "snapshots".
        input_data = [args[0]]
        args = (input_data,)

        unet_output = self.unet_model(*args, **kwargs)
        unet_output = self.final_act(unet_output)

        if unet_output.shape[1] == 1 and self.squeeze_single_class:
            unet_output = unet_output.squeeze(1)

        model_output = ModelOutput(unet_output)

        return model_output

    def freeze_encoder(self):
        raise NotImplementedError()

    def freeze_decoder(self):
        freeze_module(self.unet_model)


def _extract_prefix_keys(d: dict, prefix: str) -> dict:
    """Pop keys starting with `prefix` from `d` and return them with the prefix stripped."""
    extracted_dict = {}
    keys_to_del = []
    for k, v in d.items():
        if k.startswith(prefix):
            extracted_dict[k[len(prefix):]] = v
            keys_to_del.append(k)

    for k in keys_to_del:
        del d[k]

    return extracted_dict
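For orientation, here is a minimal usage sketch of the factory above, mirroring the model_args block in the YAML config that follows. It assumes mmseg is installed and that the class defined above is importable; the decoder_-prefixed keyword arguments are stripped by _extract_prefix_keys and forwarded to the mmseg ASPPHead constructor, so the exact argument names are assumptions about that constructor rather than something verified in this diff.

# Usage sketch (assumptions: mmseg is installed, the factory above is importable).
import torch

factory = GenericUnetModelFactory()
model = factory.build_model(
    task="segmentation",
    decoder="ASPPHead",                 # looked up via getattr on mmseg.models.decode_heads
    decoder_dilations=(1, 6, 12, 18),   # the "decoder_" prefix is stripped before the call
    decoder_channels=256,
    decoder_in_channels=6,
    decoder_num_classes=2,
)

# The wrapper expects a (B, C, H, W) tensor and returns a ModelOutput.
output = model(torch.randn(2, 6, 224, 224))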
@@ -0,0 +1,148 @@
# lightning.pytorch==2.1.1
seed_everything: 42
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  # precision: 16-mixed
  logger:
    class_path: TensorBoardLogger
    init_args:
      save_dir: tests/
      name: all_ecos_random
  callbacks:
    - class_path: RichProgressBar
    - class_path: LearningRateMonitor
      init_args:
        logging_interval: epoch
    - class_path: EarlyStopping
      init_args:
        monitor: val/loss
        patience: 100
  max_epochs: 5
  check_val_every_n_epoch: 1
  log_every_n_steps: 20
  enable_checkpointing: true
  default_root_dir: tests/
data:
  class_path: GenericNonGeoSegmentationDataModule
  init_args:
    batch_size: 2
    num_workers: 4
    train_transform:
      - class_path: albumentations.HorizontalFlip
        init_args:
          p: 0.5
      - class_path: albumentations.Rotate
        init_args:
          limit: 30
          border_mode: 0 # cv2.BORDER_CONSTANT
          value: 0
          # mask_value: 1
          p: 0.5
      - class_path: ToTensorV2
    dataset_bands:
      - COASTAL_AEROSOL
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      - CIRRUS
      - THEMRAL_INFRARED_1
      - THEMRAL_INFRARED_2
    output_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    rgb_indices:
      - 2
      - 1
      - 0
    train_data_root: tests/
    train_label_data_root: tests/
    val_data_root: tests/
    val_label_data_root: tests/
    test_data_root: tests/
    test_label_data_root: tests/
    img_grep: "segmentation*input*.tif"
    label_grep: "segmentation*label*.tif"
    means:
      - 547.36707
      - 898.5121
      - 1020.9082
      - 2665.5352
      - 2340.584
      - 1610.1407
    stds:
      - 411.4701
      - 558.54065
      - 815.94025
      - 812.4403
      - 1113.7145
      - 1067.641
    no_label_replace: -1
    no_data_replace: 0
    num_classes: 2
model:
  class_path: terratorch.tasks.SemanticSegmentationTask
  init_args:
    model_args:
      decoder: "ASPPHead"
      decoder_dilations: [1, 6, 12, 18]
      decoder_channels: 256
      decoder_in_channels: 6
      decoder_num_classes: 2
      in_channels: 6
      bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
      # num_frames: 1
      head_dropout: 0.5708022831486758
      head_final_act: torch.nn.ReLU
      head_learned_upscale_layers: 2
      num_classes: 2
    loss: ce
    # aux_heads:
    #   - name: aux_head
    #     decoder: IdentityDecoder
    #     decoder_args:
    #       decoder_out_index: 2
    #       head_dropout: 0.5
    #       head_channel_list:
    #         - 64
    #       head_final_act: torch.nn.ReLU
    # aux_loss:
    #   aux_head: 0.4
    ignore_index: -1
    # freeze_encoder: false
    # freeze_decoder: false
    model_factory: GenericUnetModelFactory

    # uncomment this block for tiled inference
    # tiled_inference_parameters:
    #   h_crop: 224
    #   h_stride: 192
    #   w_crop: 224
    #   w_stride: 192
    #   average_patches: true
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 0.00013524680528283027
    weight_decay: 0.047782217873995426
lr_scheduler:
  class_path: ReduceLROnPlateau
  init_args:
    monitor: val/loss
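For reference, a config like this is typically launched through the terratorch command line interface (a LightningCLI wrapper), along the lines of:

terratorch fit --config <path_to_this_config>.yaml

The path is a placeholder and the exact invocation is an assumption for illustration, not part of this diff.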