
Running inference on prithvi EO v2 models fails with terratorch predict ... #438

Open
WanjiruCate opened this issue Feb 20, 2025 · 6 comments

@WanjiruCate

Describe the issue
Running inference with a fine-tuned model via the CLI fails on an image downloaded from GeoStudio:

terratorch predict -c  file.yaml --ckpt_path epoch.ckpt --predict_output_dir outputs/ --data.init_args.predict_data_root inputs/


Screenshots or log output (optional)

Log Output

(app-root) sh-5.1$ terratorch predict -c /working/sen1floods11/config_deploy.yaml --ckpt_path /working/sen1floods11/968/eaafc82164b842698d44f0e699c1734a/checkpoints/epoch=0.ckpt --predict_output_dir /working/outputs/ --data.init_args.predict_data_root /working/inputs//
INFO:albumentations.check_version:A new version of Albumentations is available: 2.0.4 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations
INFO: Seed set to 0
INFO:lightning.fabric.utilities.seed:Seed set to 0
/opt/miniforge/lib/python3.12/site-packages/terratorch/models/decoders/upernet_decoder.py:37: UserWarning: DeprecationWarning: scale_modules is deprecated and will be removed in future versions. Use LearnedInterpolateToPyramidal neck instead.
warnings.warn(
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: Restoring states from the checkpoint path at /working/sen1floods11/968/eaafc82164b842698d44f0e699c1734a/checkpoints/epoch=0.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Restoring states from the checkpoint path at /working/sen1floods11/968/eaafc82164b842698d44f0e699c1734a/checkpoints/epoch=0.ckpt
/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py:277: Be aware that when using ckpt_path, callbacks used to create the checkpoint need to be provided during Trainer instantiation. Please add the following callbacks: ["EarlyStopping{'monitor': 'val/loss', 'mode': 'min'}"].
/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:362: The dirpath has changed from '/working/sen1floods11/968/eaafc82164b842698d44f0e699c1734a/checkpoints' to '/working/checkpoints', therefore best_model_score, kth_best_model_path, kth_value, last_model_path and best_k_models won't be reloaded. Only best_model_path will be reloaded.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: Loaded model weights from the checkpoint at /working/sen1floods11/968/eaafc82164b842698d44f0e699c1734a/checkpoints/epoch=0.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Loaded model weights from the checkpoint at /working/sen1floods11/968/eaafc82164b842698d44f0e699c1734a/checkpoints/epoch=0.ckpt
Predicting DataLoader 0: 0% 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/opt/miniforge/lib/python3.12/site-packages/einops/einops.py", line 532, in reduce
return _apply_recipe(
^^^^^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/einops/einops.py", line 235, in _apply_recipe
init_shapes, axes_reordering, reduced_axes, added_axes, final_shapes, n_axes_w_added = _reconstruct_from_shape(
^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/einops/einops.py", line 188, in _reconstruct_from_shape_uncached
raise EinopsError(f"Shape mismatch, can't divide axis of length {length} in chunks of {known_product}")
einops.EinopsError: Shape mismatch, can't divide axis of length 7800 in chunks of 88

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/opt/miniforge/bin/terratorch", line 8, in <module>
sys.exit(main())
^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/terratorch/__main__.py", line 9, in main
_ = build_lightning_cli()
^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/terratorch/cli_tools.py", line 457, in build_lightning_cli
return MyLightningCLI(
^^^^^^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/cli.py", line 396, in __init__
self._run_subcommand(self.subcommand)
File "/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/cli.py", line 706, in _run_subcommand
fn(**fn_kwargs)
File "/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 859, in predict
return call._call_and_handle_interrupt(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 898, in _predict_impl
results = self._run(model, ckpt_path=ckpt_path)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 982, in _run
results = self._run_stage()
^^^^^^^^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 1021, in _run_stage
return self.predict_loop.run()
^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/loops/utilities.py", line 179, in _decorator
return loop_run(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/loops/prediction_loop.py", line 125, in run
self._predict_step(batch, batch_idx, dataloader_idx, dataloader_iter)
File "/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/loops/prediction_loop.py", line 255, in _predict_step
predictions = call._call_strategy_hook(trainer, "predict_step", *step_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 323, in _call_strategy_hook
output = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py", line 438, in predict_step
return self.lightning_module.predict_step(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/terratorch/tasks/segmentation_tasks.py", line 348, in predict_step
model_output: ModelOutput = self(x, **rest)
^^^^^^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/torchgeo/trainers/base.py", line 78, in forward
return self.model(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/terratorch/models/pixel_wise_model.py", line 123, in forward
features = prepare(features)
^^^^^^^^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/torch/nn/modules/container.py", line 217, in forward
input = module(input)
^^^^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/terratorch/models/necks.py", line 154, in forward
encoded = rearrange(
^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/einops/einops.py", line 600, in rearrange
return reduce(tensor, pattern, reduction="rearrange", **axes_lengths)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/miniforge/lib/python3.12/site-packages/einops/einops.py", line 542, in reduce
raise EinopsError(message + "\n {}".format(e))
einops.EinopsError: Error while processing rearrange-reduction pattern "batch (t h w) e -> batch (t e) h w".
Input tensor shape: torch.Size([1, 7800, 1024]). Additional info: {'batch': 1, 't': 1, 'h': 88}.
Shape mismatch, can't divide axis of length 7800 in chunks of 88

Expected behavior (optional)
The prediction should complete without errors.

Deployment information (optional)

  • TerraTorch version: git+https://git@github.com/IBM/terratorch.git@1fa07cd806
  • Installation source: pip install git+https://git@github.com/IBM/terratorch.git@1fa07cd806
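The shape error in the first log is consistent with the `ReshapeTokensToImage` neck assuming a square patch grid: the extra info `{'batch': 1, 't': 1, 'h': 88}` matches the integer square root of 7800, and 7800 tokens cannot be split into rows of 88. A minimal sketch of that arithmetic in plain Python; the helpers are hypothetical, and the square-grid rule is inferred from the error message rather than copied from the TerraTorch source.

```python
import math


def square_grid_side(num_tokens: int) -> int:
    """Grid side a square reshape would infer: the integer square root (hypothetical helper)."""
    return math.isqrt(num_tokens)


def fits_square_grid(num_tokens: int) -> bool:
    """True only if the tokens can be laid out as an h x h grid."""
    side = square_grid_side(num_tokens)
    return side * side == num_tokens


# Token count from the failing run in the log.
print(square_grid_side(7800))  # 88, matching 'h': 88 in the error
print(fits_square_grid(7800))  # False: 7800 is not a multiple of 88
print(fits_square_grid(7744))  # True: 88 * 88 = 7744
```

The largest perfect square not exceeding 7800 is 88 × 88 = 7744, the token count suggested in the replies.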
@WanjiruCate
Author

@paolofraccaro suggested using this function to resize and pad the image so that it is ready for the model.

@WanjiruCate
Author

With @romeokienzler's suggestion:

you can crop the image in the data loader config
so that you get 7744 tokens;
just add a resize transform
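For a single time step, the token count is (H / p) × (W / p) for patch size p. The sketch below assumes p = 16 for `prithvi_eo_v2_300` (an assumption; check the backbone configuration). Under that assumption, 88 × 88 = 7744 tokens corresponds to a 1408 × 1408 input, and the 512 × 512 `Resize` in the config below gives 32 × 32 = 1024 tokens. The hypothetical `center_crop_to_square_grid` helper shows the crop variant: it crops an (H, W, C) array to the largest square whose side is a multiple of p.

```python
import numpy as np

PATCH_SIZE = 16  # assumed patch size for prithvi_eo_v2_300


def num_tokens(height: int, width: int, patch: int = PATCH_SIZE) -> int:
    """Tokens for one time step, excluding any CLS token."""
    if height % patch or width % patch:
        raise ValueError(f"{height}x{width} is not divisible by patch size {patch}")
    return (height // patch) * (width // patch)


def center_crop_to_square_grid(image: np.ndarray, patch: int = PATCH_SIZE) -> np.ndarray:
    """Center-crop (H, W, C) to the largest square side that is a multiple of `patch`."""
    h, w = image.shape[:2]
    side = (min(h, w) // patch) * patch
    if side == 0:
        raise ValueError("image is smaller than one patch")
    top = (h - side) // 2
    left = (w - side) // 2
    return image[top:top + side, left:left + side]


print(num_tokens(1408, 1408))  # 7744 = 88 * 88
print(num_tokens(512, 512))    # 1024 = 32 * 32

# Dummy 6-band array with the scene size reported in the second log.
scene = np.zeros((1278, 2325, 6), dtype=np.float32)
print(center_crop_to_square_grid(scene).shape)  # (1264, 1264, 6)
```

Cropping discards the scene edges. Resizing keeps the full extent but changes the output grid, so the prediction no longer matches the input raster's size.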
Config file
data:
  class_path: terratorch.datamodules.GenericNonGeoSegmentationDataModule
  init_args:
    allow_substring_split_file: true
    batch_size: 4
    constant_scale: 1.0
    dataset_bands:
    - 2
    - 1
    - 0
    - 3
    - 4
    - 5
    drop_last: true
    expand_temporal_dimension: false
    ignore_split_file_extensions: true
    img_grep: '*_S2GeodnHand.tif'
    label_grep: '*_LabelHand.tif'
    means:
    - 0.12520133
    - 0.13471393
    - 0.107582
    - 0.3236181
    - 0.2341743
    - 0.15878009
    no_data_replace: 0.0
    no_label_replace: -1
    num_classes: 2
    num_workers: 2
    output_bands:
    - 2
    - 1
    - 0
    - 3
    - 4
    - 5
    pin_memory: false
    reduce_zero_label: false
    rgb_indices:
    - 2
    - 1
    - 0
    stds:
    - 0.07323416
    - 0.06783548
    - 0.07145836
    - 0.09489725
    - 0.07938496
    - 0.07089546
    test_data_root: /data/geodata-dbce399c854511efb3260a580a830dad/training_data
    test_label_data_root: /data/geodata-dbce399c854511efb3260a580a830dad/labels
    test_split: /data/geodata-dbce399c854511efb3260a580a830dad/split_files/test_data.txt
    test_transform:
    - class_path: albumentations.Resize
      init_args:
        always_apply: false
        height: 512
        interpolation: 1
        p: 1.0
        width: 512
    - class_path: albumentations.pytorch.ToTensorV2
      init_args:
        always_apply: true
        p: 1.0
        transpose_mask: false
    train_data_root: /data/geodata-dbce399c854511efb3260a580a830dad/training_data
    train_label_data_root: /data/geodata-dbce399c854511efb3260a580a830dad/labels
    train_split: /data/geodata-dbce399c854511efb3260a580a830dad/split_files/train_data.txt
    train_transform:
    - class_path: albumentations.Resize
      init_args:
        always_apply: false
        height: 512
        interpolation: 1
        p: 1.0
        width: 512
    - class_path: albumentations.RandomCrop
      init_args:
        always_apply: false
        height: 224
        p: 1.0
        width: 224
    - class_path: albumentations.HorizontalFlip
      init_args:
        always_apply: false
        p: 0.5
    - class_path: albumentations.VerticalFlip
      init_args:
        always_apply: false
        p: 0.5
    - class_path: albumentations.pytorch.ToTensorV2
      init_args:
        always_apply: true
        p: 1.0
        transpose_mask: false
    val_data_root: /data/geodata-dbce399c854511efb3260a580a830dad/training_data
    val_label_data_root: /data/geodata-dbce399c854511efb3260a580a830dad/labels
    val_split: /data/geodata-dbce399c854511efb3260a580a830dad/split_files/val_data.txt
    val_transform:
    - class_path: albumentations.Resize
      init_args:
        always_apply: false
        height: 512
        interpolation: 1
        p: 1.0
        width: 512
    - class_path: albumentations.pytorch.ToTensorV2
      init_args:
        always_apply: true
        p: 1.0
        transpose_mask: false
deploy_config_file: true
model:
  class_path: terratorch.tasks.SemanticSegmentationTask
  init_args:
    freeze_backbone: false
    freeze_decoder: false
    freeze_head: false
    ignore_index: -1
    loss: ce
    lr: 0.001
    model_args:
      backbone: prithvi_eo_v2_300
      backbone_bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
      backbone_pretrained: false
      decoder: UperNetDecoder
      decoder_channels: 256
      decoder_scale_modules: true
      head_dropout: 0.1
      necks:
      - indices:
        - 5
        - 11
        - 17
        - 23
        name: SelectIndices
      - name: ReshapeTokensToImage
      - name: LearnedInterpolateToPyramidal
      num_classes: 2
      rescale: true
    model_factory: EncoderDecoderFactory
    output_most_probable: true
    plot_on_val: 10
out_dtype: int16
seed_everything: 0
trainer:
  accelerator: auto
  accumulate_grad_batches: 1
  barebones: false
  check_val_every_n_epoch: 1
  detect_anomaly: false
  devices: auto
  enable_checkpointing: true
  fast_dev_run: false
  inference_mode: true
  log_every_n_steps: 10
  logger: false
  max_epochs: 2
  max_steps: -1
  num_nodes: 1
  overfit_batches: 0.0
  precision: 16-mixed
  reload_dataloaders_every_n_epochs: 0
  strategy: auto
  sync_batchnorm: false
  use_distributed_sampler: true

With this config, I get this error:

(app-root) bash-5.1$ terratorch predict -c /working/sen1floods11/resize_512/config_deploy.yaml --ckpt_path /working/sen1floods11/resize_512/epoch=0.ckpt --predict_output_dir /working/outputs/ --data.init_args.predict_data_root /working/data-cos/ 
INFO:albumentations.check_version:A new version of Albumentations is available: 2.0.4 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations
INFO: Seed set to 0
INFO:lightning.fabric.utilities.seed:Seed set to 0
/opt/miniforge/lib/python3.12/site-packages/terratorch/models/decoders/upernet_decoder.py:37: UserWarning: DeprecationWarning: scale_modules is deprecated and will be removed in future versions. Use LearnedInterpolateToPyramidal neck instead.
  warnings.warn(
INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: Restoring states from the checkpoint path at /working/sen1floods11/resize_512/epoch=0.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Restoring states from the checkpoint path at /working/sen1floods11/resize_512/epoch=0.ckpt
/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py:277: Be aware that when using `ckpt_path`, callbacks used to create the checkpoint need to be provided during `Trainer` instantiation. Please add the following callbacks: ["EarlyStopping{'monitor': 'val/loss', 'mode': 'min'}"].
/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:362: The dirpath has changed from '/working/sen1floods11/resize_512' to '/working/checkpoints', therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: Loaded model weights from the checkpoint at /working/sen1floods11/resize_512/epoch=0.ckpt
INFO:lightning.pytorch.utilities.rank_zero:Loaded model weights from the checkpoint at /working/sen1floods11/resize_512/epoch=0.ckpt
Predicting DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/opt/miniforge/bin/terratorch", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/opt/miniforge/lib/python3.12/site-packages/terratorch/__main__.py", line 9, in main
    _ = build_lightning_cli()
        ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniforge/lib/python3.12/site-packages/terratorch/cli_tools.py", line 457, in build_lightning_cli
    return MyLightningCLI(
           ^^^^^^^^^^^^^^^
  File "/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/cli.py", line 396, in __init__
    self._run_subcommand(self.subcommand)
  File "/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/cli.py", line 706, in _run_subcommand
    fn(**fn_kwargs)
  File "/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 859, in predict
    return call._call_and_handle_interrupt(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 898, in _predict_impl
    results = self._run(model, ckpt_path=ckpt_path)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 982, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 1021, in _run_stage
    return self.predict_loop.run()
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/loops/utilities.py", line 179, in _decorator
    return loop_run(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/loops/prediction_loop.py", line 125, in run
    self._predict_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/loops/prediction_loop.py", line 268, in _predict_step
    call._call_callback_hooks(trainer, "on_predict_batch_end", predictions, *hook_kwargs.values())
  File "/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 222, in _call_callback_hooks
    fn(trainer, trainer.lightning_module, *args, **kwargs)
  File "/opt/miniforge/lib/python3.12/site-packages/lightning/pytorch/callbacks/prediction_writer.py", line 156, in on_predict_batch_end
    self.write_on_batch_end(trainer, pl_module, outputs, batch_indices, batch, batch_idx, dataloader_idx)
  File "/opt/miniforge/lib/python3.12/site-packages/terratorch/cli_tools.py", line 173, in write_on_batch_end
    save_prediction(prediction, file_name, output_dir, dtype=trainer.out_dtype)
  File "/opt/miniforge/lib/python3.12/site-packages/terratorch/cli_tools.py", line 102, in save_prediction
    result = np.where(mask == 1, -1, prediction.detach().cpu())
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: operands could not be broadcast together with shapes (1278,2325) () (512,512) 

Box folder with artifacts: https://ibm.box.com/s/4nrf2jsfskjp5sbjg0ltrm5mh62f0pks
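The `ValueError` at the end of the second log is ordinary NumPy broadcasting. The nodata mask is built from the original GeoTIFF (1278 × 2325), while the model output is 512 × 512 after the `Resize` transform, and `np.where` cannot broadcast those shapes. A minimal reproduction with dummy arrays follows. The shapes come from the log, the values are invented, and `where_or_error` is a hypothetical helper.

```python
import numpy as np


def where_or_error(mask_shape, pred_shape):
    """Apply a nodata mask like save_prediction does; return the result shape or the error text."""
    mask = np.zeros(mask_shape, dtype=np.int64)            # nodata mask at the input raster's size
    prediction = np.zeros(pred_shape, dtype=np.int64)      # model output size
    try:
        return np.where(mask == 1, -1, prediction).shape
    except ValueError as err:
        return str(err)


print(where_or_error((1278, 2325), (512, 512)))  # operands could not be broadcast together ...
print(where_or_error((512, 512), (512, 512)))    # (512, 512)
```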

@WanjiruCate
Author

WanjiruCate commented Feb 20, 2025

With @paolofraccaro's recommendation, I get this error:

Traceback (most recent call last):
  File "/working/test/paolo_inference.py", line 354, in <module>
    main(**vars(args))
  File "/working/test/paolo_inference.py", line 269, in main
    pred = run_model(input_data, temporal_coords, location_coords,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/working/test/paolo_inference.py", line 189, in run_model
    x = datamodule.aug(x)['image']
        ^^^^^^^^^^^^^^^^^
  File "/opt/miniforge/lib/python3.12/site-packages/terratorch/datamodules/generic_pixel_wise_data_module.py", line 66, in __call__
    raise Exception(msg)
Exception: Expected batch to have 5 or 4 dimensions, but got 3
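The exception says `datamodule.aug` received a 3-D input, while it expects a batch with 4 or 5 dimensions (presumably (B, C, H, W) or (B, C, T, H, W)). One likely cause is a missing batch axis. This assumes that `x` in `run_model` is a single (C, H, W) image once the temporal axis is dropped for a model without temporal inputs; that assumption is not confirmed in the thread. A hypothetical `ensure_batched` helper that adds the axis:

```python
import numpy as np


def ensure_batched(x, expected_ndims=(4, 5)):
    """Add a leading batch axis to a single sample if that makes its rank acceptable.

    Hypothetical helper. Works for NumPy arrays and torch tensors,
    since both expose `.ndim` and support `x[None, ...]`.
    """
    if x.ndim in expected_ndims:
        return x
    if x.ndim + 1 in expected_ndims:
        return x[None, ...]
    raise ValueError(f"cannot batch an input with {x.ndim} dimensions")


single = np.zeros((6, 512, 512), dtype=np.float32)  # (C, H, W) with the config's 6 bands
print(ensure_batched(single).shape)  # (1, 6, 512, 512)
```

If the model output keeps the batch axis, `pred[0]` drops it again before saving.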

@WanjiruCate
Author

WanjiruCate commented Feb 20, 2025

Is there an updated inference.py for models without temporal and location inputs? This one was shared: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/blob/main/inference.py, but I get the dimension error above.

I have checked the floods and burn scars Hugging Face repos, but found nothing there.

@Joao-L-S-Almeida self-assigned this Feb 20, 2025
@Joao-L-S-Almeida
Member

Hi @WanjiruCate, the issue happens in the function save_prediction in terratorch/cli_tools.py. When TerraTorch saves the predicted image, it builds a nodata mask from the input file and applies it to the prediction, which assumes that the input and output images have the same size:

def save_prediction(prediction, input_file_name, out_dir, dtype:str="int16"):
    mask, metadata = open_tiff(input_file_name)
    mask = np.where(mask == metadata["nodata"], 1, 0)
    mask = np.max(mask, axis=0)
    result = np.where(mask == 1, -1, prediction.detach().cpu())

That assumption does not always hold, for example when the input is resized before inference.
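One way to relax that assumption is to resample the prediction back to the input raster's height and width before applying the nodata mask. The sketch below uses NumPy only and is not TerraTorch's actual fix; `nearest_resize` and `restore_to_input_grid` are hypothetical helpers. It uses nearest-neighbour indexing so that class labels stay integers. It assumes `prediction` has already been converted with `prediction.detach().cpu().numpy()` and that it covers the whole scene, as with `Resize`. A cropped prediction would instead need to be pasted back at its offset.

```python
import numpy as np


def nearest_resize(pred: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Nearest-neighbour resample of a 2-D label map to (out_h, out_w)."""
    in_h, in_w = pred.shape
    rows = (np.arange(out_h) * in_h) // out_h  # source row for each output row
    cols = (np.arange(out_w) * in_w) // out_w  # source column for each output column
    return pred[rows[:, None], cols[None, :]]


def restore_to_input_grid(pred: np.ndarray, nodata_mask: np.ndarray, fill: int = -1) -> np.ndarray:
    """Resize `pred` to the mask's shape if needed, then set nodata pixels to `fill`."""
    if pred.shape != nodata_mask.shape:
        pred = nearest_resize(pred, *nodata_mask.shape)
    return np.where(nodata_mask == 1, fill, pred)


# Dummy data with the shapes from the second log.
pred = np.random.default_rng(0).integers(0, 2, size=(512, 512))
mask = np.zeros((1278, 2325), dtype=np.int64)
mask[:10, :] = 1  # pretend the first 10 rows are nodata
out = restore_to_input_grid(pred, mask)
print(out.shape)  # (1278, 2325)
```

Writing the result at the input's size also keeps it consistent with the metadata read from the input file by `open_tiff`.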

@Joao-L-S-Almeida
Member

Joao-L-S-Almeida commented Feb 24, 2025

@WanjiruCate
When I modify this function as follows:

def save_prediction(prediction, input_file_name, out_dir, dtype:str="int16"):
    mask, metadata = open_tiff(input_file_name)
    mask = np.where(mask == metadata["nodata"], 1, 0)
    mask = np.max(mask, axis=0)
    #result = np.where(mask == 1, -1, prediction.detach().cpu())
    result = prediction.detach().cpu()

It runs, but the output is a black image.
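A possible explanation, not confirmed in this thread: with `num_classes: 2`, `output_most_probable: true` and `out_dtype: int16`, the saved raster holds only the class ids 0 and 1. Many image viewers stretch int16 values over the full range, so such a raster looks almost black. An all-zero prediction would also look black, so checking the unique values distinguishes the two cases. A sketch with hypothetical helpers for the check and for a viewing-only rescale to 0 to 255:

```python
import numpy as np


def describe_labels(pred: np.ndarray) -> dict:
    """Count pixels per predicted value."""
    values, counts = np.unique(pred, return_counts=True)
    return {int(v): int(c) for v, c in zip(values, counts)}


def labels_to_uint8(pred: np.ndarray, num_classes: int = 2) -> np.ndarray:
    """Spread class ids 0..num_classes-1 over 0..255 for viewing only.

    Negative values (e.g. the -1 nodata fill) are clipped to 0 and therefore shown as black.
    """
    step = 255 // max(num_classes - 1, 1)
    scaled = np.clip(pred, 0, num_classes - 1) * step
    return scaled.astype(np.uint8)


pred = np.array([[0, 1], [1, -1]], dtype=np.int16)
print(describe_labels(pred))  # {-1: 1, 0: 1, 1: 2}
print(labels_to_uint8(pred))
```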
