[WIP] Added support for temporal segmentation data in encoder decoder factory #355

Draft · wants to merge 4 commits into base: main
Changes from 1 commit
34 changes: 33 additions & 1 deletion terratorch/models/encoder_decoder_factory.py
@@ -1,6 +1,7 @@
# Copyright contributors to the Terratorch project


import torch
from torch import nn

from terratorch.models.model import (
@@ -20,6 +21,31 @@
SUPPORTED_TASKS = PIXEL_WISE_TASKS + SCALAR_TASKS



class TemporalWrapper(nn.Module):
    def __init__(self, encoder, pooling="mean"):
        super().__init__()
        self.encoder = encoder
        if pooling == "mean":
            self.pooling = torch.mean
        elif pooling == "max":
            self.pooling = torch.max
        else:
Collaborator comment:
It would be good to have a method concat which merges the embeddings of all timestamps along the embedding dim, e.g. for testing how much accuracy we lose if the time stamps are averaged before the decoder.

Collaborator comment:
This requires fixed time stamps defined by the user so that the decoder gets the correct out_channels.

            msg = "Pooling must be 'mean' or 'max'"
            raise ValueError(msg)

    def forward(self, x):
        # x is a list of tensors, each corresponding to a different timestamp
        features = [self.encoder(t) for t in x]
Collaborator comment:
I like your approach, thanks for starting this draft!

Other models in terratorch process data in the format [B, C, T, H, W]. That is also the format in which data is provided by the generic data modules. It might be good to follow this pattern and iterate over dim=2 instead of expecting a list.
        # Stack features along a new dimension and apply pooling
        features = torch.stack(features, dim=0)
        if self.pooling == torch.max:
            pooled_features, _ = self.pooling(features, dim=0)
        else:
            pooled_features = self.pooling(features, dim=0)
        return pooled_features
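
Following the reviewers' suggestion above, here is a minimal, hypothetical sketch of what a "concat" option could look like: it merges the per-timestamp embeddings along the channel dimension instead of pooling them, which is why the number of timestamps must be fixed by the user so the decoder can be built with the correct number of input channels. The class name and interface below are illustrative only, not part of this PR.

import torch
from torch import nn

class TemporalConcatWrapper(nn.Module):
    """Illustrative sketch: concatenate per-timestamp embeddings along the
    channel dimension instead of averaging or max-pooling them."""

    def __init__(self, encoder: nn.Module, num_timestamps: int):
        super().__init__()
        self.encoder = encoder
        # The decoder must be built for num_timestamps * C input channels,
        # so the number of timestamps has to be fixed by the user.
        self.num_timestamps = num_timestamps

    def forward(self, x):
        # x: list of [B, C, H, W] tensors, one per timestamp (as in this draft)
        features = [self.encoder(t) for t in x]
        # Concatenate along the channel dimension -> [B, T * C, H, W]
        return torch.cat(features, dim=1)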

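Similarly, a hedged sketch of the forward pass if the wrapper followed the [B, C, T, H, W] convention mentioned by the reviewer and iterated over dim=2 instead of expecting a list. It assumes, as the current draft does, that the encoder returns a single tensor per timestep; the class name is illustrative.

import torch
from torch import nn

class TemporalWrapperBCTHW(nn.Module):
    """Illustrative sketch: same pooling idea, but consuming a single
    [B, C, T, H, W] tensor as produced by the generic data modules."""

    def __init__(self, encoder: nn.Module, pooling: str = "mean"):
        super().__init__()
        if pooling not in ("mean", "max"):
            msg = "Pooling must be 'mean' or 'max'"
            raise ValueError(msg)
        self.encoder = encoder
        self.pooling = pooling

    def forward(self, x):
        # x: [B, C, T, H, W]; run the encoder once per timestep along dim=2
        features = [self.encoder(x[:, :, t]) for t in range(x.shape[2])]
        stacked = torch.stack(features, dim=0)  # [T, B, ...]
        if self.pooling == "max":
            return stacked.max(dim=0).values
        return stacked.mean(dim=0)
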

def _get_backbone(backbone: str | nn.Module, **backbone_kwargs) -> nn.Module:
if isinstance(backbone, nn.Module):
return backbone
@@ -73,7 +99,9 @@ def build_model(
    num_classes: int | None = None,
    necks: list[dict] | None = None,
    aux_decoders: list[AuxiliaryHead] | None = None,
    rescale: bool = True,  # noqa: FBT002, FBT001,
    use_temporal: bool = False,
    temporal_pooling: str = "mean",
    **kwargs,
) -> Model:
"""Generic model factory that combines an encoder and decoder, together with a head, for a specific task.
Expand Down Expand Up @@ -136,6 +164,10 @@ def build_model(
        decoder, channel_list, decoder_kwargs, head_kwargs, num_classes=num_classes
    )

    # Add temporal wrapper if enabled
    if use_temporal:
        backbone = TemporalWrapper(backbone, pooling=temporal_pooling)
Collaborator comment:
I would apply the wrapper when building the backbone, i.e. backbone_use_temporal is passed as use_temporal to _get_backbone. The only important thing is that you save the backbone.out_channels in your wrapper as self.out_channels as well (for concat you have to modify it as well).


    if aux_decoders is None:
        _check_all_args_used(kwargs)
        return _build_appropriate_model(task, backbone, decoder, head_kwargs, necks=neck_list, decoder_includes_head=decoder_includes_head, rescale=rescale)
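
To make the last review comment concrete, a rough sketch of applying the wrapper inside _get_backbone while preserving out_channels. It assumes _get_backbone resolves string names to modules below the lines shown in this diff; the helper name used for that step is hypothetical.

def _get_backbone(backbone: str | nn.Module, use_temporal: bool = False,
                  temporal_pooling: str = "mean", **backbone_kwargs) -> nn.Module:
    if isinstance(backbone, nn.Module):
        base = backbone
    else:
        # resolve the string name however the existing factory already does
        base = _build_backbone_by_name(backbone, **backbone_kwargs)  # hypothetical helper
    if not use_temporal:
        return base
    wrapper = TemporalWrapper(base, pooling=temporal_pooling)
    # Preserve the attribute the factory uses to size the decoder.
    # (A concat mode would need num_timestamps * base.out_channels here instead.)
    wrapper.out_channels = base.out_channels
    return wrapper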
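For context, a hedged usage sketch of the new arguments; the backbone and decoder names are placeholders and the exact factory entry point may differ from what is shown in this diff.

from terratorch.models.encoder_decoder_factory import EncoderDecoderFactory

factory = EncoderDecoderFactory()
# Hypothetical call: per-timestamp encoder features are max-pooled before the decoder.
model = factory.build_model(
    task="segmentation",
    backbone="prithvi_eo_v2_300",   # placeholder backbone name
    decoder="FCNDecoder",           # placeholder decoder name
    num_classes=2,
    use_temporal=True,
    temporal_pooling="max",
)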