
out_channels and ReshapeTokensToImage for temporal tasks #402

Open
daniszw opened this issue Feb 4, 2025 · 1 comment
daniszw commented Feb 4, 2025

The out_channels attribute of the backbones is used to define how to connect the model to necks and decoders. The usual way to connect Prithvi to a 2D segmentation decoder is through the ReshapeTokensToImage neck, which rearranges the transformer output of shape (batch, num_tokens, embed_dim) into the (batch, channels, H, W) layout that the decoder's convolutional layers expect.
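As an illustrative sketch (not the actual TerraTorch implementation), the single-frame rearrangement could look like this; the function name, the square-grid assumption, and the cls-token-first assumption are all hypothetical:

```python
import torch

def reshape_tokens_to_image(x, remove_cls_token=True):
    # Sketch of a ReshapeTokensToImage-style neck.
    # Input:  (batch, num_tokens, embed_dim) from a ViT backbone.
    # Output: (batch, embed_dim, H, W) for convolutional decoders.
    if remove_cls_token:
        x = x[:, 1:, :]  # assumes the cls_token is the first token
    b, n, c = x.shape
    h = w = int(n ** 0.5)  # assumes a square patch grid
    return x.transpose(1, 2).reshape(b, c, h, w)

tokens = torch.randn(2, 1 + 14 * 14, 768)  # e.g. ViT-B with a 14x14 grid
img = reshape_tokens_to_image(tokens)
print(img.shape)  # torch.Size([2, 768, 14, 14])
```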

Since Prithvi is a temporal model (it can consume a sequence of images), the model output contains tokens for the entire input sequence. This means that when the ReshapeTokensToImage neck is used with the correct effective_time_dim, the output is rearranged to (batch, T*embed_dim, H, W), effectively increasing the number of channels connected to the decoder.
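A minimal sketch of that temporal folding (an assumption about the layout, not TerraTorch's exact code): tokens for T frames are moved into the channel dimension, so the decoder sees T*embed_dim channels.

```python
import torch

def reshape_temporal_tokens(x, effective_time_dim):
    # Input:  (batch, T * tokens_per_frame, embed_dim), cls_token removed.
    # Output: (batch, T * embed_dim, H, W) -- time folded into channels.
    b, n, c = x.shape
    t = effective_time_dim
    h = w = int((n // t) ** 0.5)     # assumes a square patch grid per frame
    x = x.reshape(b, t, h * w, c)    # (B, T, H*W, C)
    x = x.permute(0, 1, 3, 2)        # (B, T, C, H*W)
    return x.reshape(b, t * c, h, w)

tokens = torch.randn(2, 3 * 14 * 14, 768)  # T=3 frames, 14x14 grid
out = reshape_temporal_tokens(tokens, effective_time_dim=3)
print(out.shape)  # torch.Size([2, 2304, 14, 14])
```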

The current solution is to make the out_channels attribute proportional to effective_time_dim. However, the actual output of the model is (batch, num_tokens, embed_dim), not (batch, num_tokens, embed_dim * effective_time_dim). The number of channels only changes after the ReshapeTokensToImage neck rearranges the model output, which I believe makes the neck, not the backbone, responsible for the change in the number of channels.
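One way to picture this argument (a hypothetical sketch, not an existing TerraTorch API): the neck itself could report the channel change it introduces, so decoders would be built from the neck's out_channels rather than a backbone attribute scaled ahead of time.

```python
class ReshapeNeckSketch:
    """Hypothetical neck that owns the channel change it performs."""

    def __init__(self, embed_dim, effective_time_dim):
        self.effective_time_dim = effective_time_dim
        # The backbone still emits embed_dim per token; only the neck's
        # rearrangement multiplies the channel count by T.
        self.out_channels = embed_dim * effective_time_dim

neck = ReshapeNeckSketch(embed_dim=768, effective_time_dim=3)
print(neck.out_channels)  # 2304
```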

The neck should also allow for different effective_time_dim values, since, in theory, the sequence of images could have variable length at inference time. It also currently assumes the cls_token is the first token to remove, which I'm not sure is always the case.
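The cls_token question could in principle be decided from the token count alone. The following heuristic is purely an assumption for illustration (not TerraTorch logic): if the tokens divide evenly into T square grids, no cls_token is assumed; if they only do so after dropping one token, a cls_token is assumed.

```python
def infer_cls_token(num_tokens, effective_time_dim):
    # Heuristic sketch: decide whether a cls_token is prepended by checking
    # if the token count matches T square patch grids, with or without one
    # extra token.
    def fits_square_grids(n):
        if n <= 0 or n % effective_time_dim:
            return False
        per_frame = n // effective_time_dim
        root = int(per_frame ** 0.5)
        return root * root == per_frame

    if fits_square_grids(num_tokens):
        return False
    if fits_square_grids(num_tokens - 1):
        return True
    raise ValueError("token count fits neither layout")

print(infer_cls_token(3 * 196 + 1, 3))  # True  (one extra token)
print(infer_cls_token(3 * 196, 3))      # False (exactly T grids)
```

Note the ambiguity this cannot resolve: some counts fit both layouts, which is one more reason an explicit flag or a trial forward pass may be needed.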


blumenstiel commented Feb 4, 2025

I think these are very important points for multi-temporal data. To phrase them as requirements, I think we should:

  1. enable the prediction of multiple masks per time step, as currently only one mask is provided (not sure how relevant this case is),
  2. enable variable time series lengths, as this might be relevant for inference (e.g. via a neck that applies a mean across the time dimension, or via 1. plus aggregation over all predicted masks),
  3. automatically infer whether a cls_token is provided (e.g. try the forward pass and set cls_token=False if it succeeds without one),
  4. somehow handle how out_channels are changed by necks.
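Requirement 2 could be sketched as a neck that averages over the time dimension, so the channel count stays embed_dim regardless of sequence length (an illustrative sketch, not an existing TerraTorch component):

```python
import torch

def temporal_mean_neck(x, effective_time_dim):
    # Average tokens over time so downstream channels do not depend on T.
    # Input:  (batch, T * tokens_per_frame, embed_dim), cls_token removed.
    # Output: (batch, tokens_per_frame, embed_dim).
    b, n, c = x.shape
    t = effective_time_dim
    x = x.reshape(b, t, n // t, c)  # (B, T, tokens_per_frame, C)
    return x.mean(dim=1)

for t in (1, 3, 6):
    tokens = torch.randn(2, t * 196, 768)
    print(temporal_mean_neck(tokens, t).shape)  # always (2, 196, 768)
```

The design point: because the mean collapses T before the tokens are reshaped to an image, out_channels stays embed_dim and the decoder no longer needs to know the sequence length.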
