diff --git a/terratorch/models/encoder_decoder_factory.py b/terratorch/models/encoder_decoder_factory.py
index efd018c3..2e4971f3 100644
--- a/terratorch/models/encoder_decoder_factory.py
+++ b/terratorch/models/encoder_decoder_factory.py
@@ -1,5 +1,7 @@
 # Copyright contributors to the Terratorch project
+
+import torch
 from typing import List
 import warnings
 import logging
@@ -23,10 +25,91 @@
 SUPPORTED_TASKS = PIXEL_WISE_TASKS + SCALAR_TASKS
 
-def _get_backbone(backbone: str | nn.Module, **backbone_kwargs) -> nn.Module:
+
+class TemporalWrapper(nn.Module):
+    def __init__(self, encoder: nn.Module, pooling: str = "mean", concat: bool = False):
+        """
+        Wrapper that applies an encoder to each time step and merges the results.
+
+        Args:
+            encoder (nn.Module): The feature extractor (backbone).
+            pooling (str): Type of temporal pooling, 'mean' or 'max'.
+            concat (bool): Whether to concatenate features along the channel dimension instead of pooling.
+        """
+        super().__init__()
+        self.encoder = encoder
+        self.concat = concat
+        self.pooling_type = pooling
+
+        if pooling not in ("mean", "max"):
+            raise ValueError("Pooling must be 'mean' or 'max'")
+
+        # Decoders need to know the channel count of the wrapped backbone.
+        if not hasattr(encoder, "out_channels"):
+            raise AttributeError("Encoder must have an `out_channels` attribute.")
+        # Channels per timestep; with `concat` the effective channel count is
+        # multiplied by the number of timesteps, which is only known at forward time.
+        self.out_channels = encoder.out_channels
+
+    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
+        """
+        Forward pass for temporal processing.
+
+        Args:
+            x (Tensor): Input tensor of shape [B, C, T, H, W].
+
+        Returns:
+            list[Tensor]: One merged tensor per feature map produced by the encoder.
+        """
+        if x.dim() != 5:
+            raise ValueError(f"Expected input shape [B, C, T, H, W], but got {x.shape}")
+
+        timesteps = x.shape[2]
+
+        # Collect the encoder outputs for every timestep, grouped by feature map.
+        features_per_map: list[list[torch.Tensor]] = []
+        for t in range(timesteps):
+            feat = self.encoder(x[:, :, t, :, :])  # extract features at timestep t
+
+            # Encoders returning a single feature map are normalized to a list.
+            if not isinstance(feat, list):
+                feat = [feat]
+
+            for i, feature_map in enumerate(feat):
+                if len(features_per_map) <= i:
+                    features_per_map.append([])
+                features_per_map[i].append(feature_map)
+
+        # Stack each feature map along a new temporal dimension: [B, C', T, H', W'].
+        stacked = [torch.stack(maps, dim=2) for maps in features_per_map]
+
+        # Merge the temporal dimension by concatenation or pooling.
+        if self.concat:
+            # Flatten T into the channel dimension, keeping the encoder's own spatial size.
+            return [feat.flatten(start_dim=1, end_dim=2) for feat in stacked]
+        if self.pooling_type == "max":
+            return [torch.max(feat, dim=2)[0] for feat in stacked]  # max pooling across T
+        return [torch.mean(feat, dim=2) for feat in stacked]  # mean pooling across T
+
+
+def _get_backbone(backbone: str | nn.Module, use_temporal: bool = False, temporal_pooling: str = "mean", **backbone_kwargs) -> nn.Module:
     if isinstance(backbone, nn.Module):
-        return backbone
-    return BACKBONE_REGISTRY.build(backbone, **backbone_kwargs)
+        model = backbone
+    else:
+        model = BACKBONE_REGISTRY.build(backbone, **backbone_kwargs)
+
+    # Optionally wrap the backbone so it can consume [B, C, T, H, W] inputs.
+    if use_temporal:
+        model = TemporalWrapper(model, pooling=temporal_pooling)
+
+    return model
 
 
 def _get_decoder_and_head_kwargs(
@@ -78,7 +161,9 @@ def build_model(
     num_classes: int | None = None,
     necks: list[dict] | None = None,
     aux_decoders: list[AuxiliaryHead] | None = None,
     rescale: bool = True,  # noqa: FBT002, FBT001
+    backbone_use_temporal: bool = False,
+    temporal_pooling: str = "mean",
     peft_config: dict | None = None,
     **kwargs,
 ) -> Model:
@@ -128,7 +213,7 @@ def build_model(
         raise NotImplementedError(msg)
 
     backbone_kwargs, kwargs = extract_prefix_keys(kwargs, "backbone_")
-    backbone = _get_backbone(backbone, **backbone_kwargs)
+    backbone = _get_backbone(backbone, use_temporal=backbone_use_temporal, temporal_pooling=temporal_pooling, **backbone_kwargs)
 
     # If patch size is not provided in the config or by the model, it might lead to errors due to irregular images.
     patch_size = backbone_kwargs.get("patch_size", None)
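Review note (not part of the patch): a minimal sketch of how the wrapper is expected to behave once this change lands. `ToyEncoder` below is a hypothetical stand-in for a real backbone; any module exposing an `out_channels` attribute and returning a feature map (or list of feature maps) should work.

```python
import torch
from torch import nn

from terratorch.models.encoder_decoder_factory import TemporalWrapper


class ToyEncoder(nn.Module):
    """Hypothetical backbone used only for illustration."""

    def __init__(self, out_channels: int = 8):
        super().__init__()
        self.out_channels = out_channels
        self.conv = nn.Conv2d(3, out_channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        return [self.conv(x)]  # single feature map, returned as a list


wrapper = TemporalWrapper(ToyEncoder(), pooling="mean")
video = torch.randn(2, 3, 4, 32, 32)  # [B, C, T, H, W]
features = wrapper(video)
print([f.shape for f in features])  # [torch.Size([2, 8, 32, 32])]: T reduced by mean pooling
```

The same path is exercised through the factory by passing `backbone_use_temporal=True` (and optionally `temporal_pooling="max"`) to `build_model`, which forwards both options to `_get_backbone`.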