Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Updated logic to update the new model versions value backbone_pretrained to false #418

Merged
merged 1 commit into from
Feb 12, 2025

Conversation

WanjiruCate
Copy link

After Finetuning a model, the config_deploy.yaml generated for inferencing usually updates some values like the pretrained to false.

For the new prithvi EO v2 models, this value was changed to backbone_pretrained and when trying to do an inference with this config, an error with the previous pretrained:false value is thrown.

In [5]: model = LightningInferenceModel.from_config(
   ...:     config_path=config_path,
   ...:     checkpoint_path=weights_path,
   ...:     # predict_dataset_bands=predict_dataset_bands,
   ...: )
INFO: Seed set to 0
INFO:lightning.fabric.utilities.seed:Seed set to 0
INFO:root:Loaded weights for HLSBands.BLUE in position 0 of patch embed
INFO:root:Loaded weights for HLSBands.GREEN in position 1 of patch embed
INFO:root:Loaded weights for HLSBands.RED in position 2 of patch embed
INFO:root:Loaded weights for HLSBands.NIR_NARROW in position 3 of patch embed
INFO:root:Loaded weights for HLSBands.SWIR_1 in position 4 of patch embed
INFO:root:Loaded weights for HLSBands.SWIR_2 in position 5 of patch embed
/Users/catherinewanjiru/miniforge3/envs/prithviEOV2/lib/python3.11/site-packages/terratorch/models/decoders/upernet_decoder.py:37: UserWarning: DeprecationWarning: scale_modules is deprecated and will be removed in future versions. Use LearnedInterpolateToPyramidal neck instead.
  warnings.warn(
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File ~/miniforge3/envs/prithviEOV2/lib/python3.11/site-packages/jsonargparse/_typehints.py:821, in adapt_typehints(val, typehint, serialize, instantiate_classes, prev_val, append, list_item, enable_path, sub_add_kwargs, default, logger)
    820 try:
--> 821     vals.append(adapt_typehints(val, subtypehint, **adapt_kwargs))
    822     break

File ~/miniforge3/envs/prithviEOV2/lib/python3.11/site-packages/jsonargparse/_typehints.py:1073, in adapt_typehints(val, typehint, serialize, instantiate_classes, prev_val, append, list_item, enable_path, sub_add_kwargs, default, logger)
   1072     val["class_path"] = get_import_path(val_class)
-> 1073     val = adapt_class_type(val, serialize, instantiate_classes, sub_add_kwargs, prev_val=prev_val)
   1074 except (ImportError, AttributeError, AssertionError, ArgumentError) as ex:

File ~/miniforge3/envs/prithviEOV2/lib/python3.11/site-packages/jsonargparse/_typehints.py:1392, in adapt_class_type(value, serialize, instantiate_classes, sub_add_kwargs, prev_val, skip_args, partial_classes)
   1391         return partial_instance
-> 1392     return instantiator_fn(val_class, **{**init_args, **dict_kwargs})
   1394 prev_init_args = prev_val.get("init_args") if isinstance(prev_val, Namespace) else None

File ~/miniforge3/envs/prithviEOV2/lib/python3.11/site-packages/jsonargparse/_common.py:165, in ClassInstantiator.__call__(self, class_type, *args, **kwargs)
    164     if class_type is cls or (subclasses and is_subclass(class_type, cls)):
--> 165         return instantiator(class_type, *args, **kwargs)
    166 return default_class_instantiator(class_type, *args, **kwargs)

File ~/miniforge3/envs/prithviEOV2/lib/python3.11/site-packages/lightning/pytorch/cli.py:808, in _InstantiatorFn.__call__(self, class_type, *args, **kwargs)
    804 with _given_hyperparameters_context(
    805     hparams=hparams,
    806     instantiator="lightning.pytorch.cli.instantiate_module",
    807 ):
--> 808     return class_type(*args, **kwargs)

File ~/miniforge3/envs/prithviEOV2/lib/python3.11/site-packages/terratorch/tasks/segmentation_tasks.py:128, in SemanticSegmentationTask.__init__(self, model_args, model_factory, model, loss, aux_heads, aux_loss, class_weights, ignore_index, lr, optimizer, optimizer_hparams, scheduler, scheduler_hparams, freeze_backbone, freeze_decoder, freeze_head, plot_on_val, class_names, tiled_inference_parameters, test_dataloaders_names, lr_overrides)
    126     self.model_factory = MODEL_FACTORY_REGISTRY.build(model_factory)
--> 128 super().__init__(task="segmentation")
    130 if model is not None:
    131     # Custom model

File ~/miniforge3/envs/prithviEOV2/lib/python3.11/site-packages/terratorch/tasks/base_task.py:24, in TerraTorchTask.__init__(self, task)
     22 self.task = task
---> 24 super().__init__()

File ~/miniforge3/envs/prithviEOV2/lib/python3.11/site-packages/torchgeo/trainers/base.py:39, in BaseTask.__init__(self, ignore)
     38 self.save_hyperparameters(ignore=ignore)
---> 39 self.configure_models()
     40 self.configure_losses()

File ~/miniforge3/envs/prithviEOV2/lib/python3.11/site-packages/terratorch/tasks/base_task.py:37, in TerraTorchTask.configure_models(self)
     35     return
---> 37 self.model: Model = self.model_factory.build_model(
     38     self.task, aux_decoders=self.aux_heads, **self.hparams["model_args"]
     39 )
     41 if self.hparams["freeze_backbone"]:

File ~/miniforge3/envs/prithviEOV2/lib/python3.11/site-packages/terratorch/models/encoder_decoder_factory.py:176, in EncoderDecoderFactory.build_model(self, task, backbone, decoder, num_classes, necks, aux_decoders, rescale, peft_config, **kwargs)
    175 if aux_decoders is None:
--> 176     _check_all_args_used(kwargs)
    177     return _build_appropriate_model(
    178         task,
    179         backbone,
   (...)
    186         rescale=rescale,
    187     )

File ~/miniforge3/envs/prithviEOV2/lib/python3.11/site-packages/terratorch/models/encoder_decoder_factory.py:66, in _check_all_args_used(kwargs)
     65 msg = f"arguments {kwargs} were passed but not used."
---> 66 raise ValueError(msg)

ValueError: arguments {'pretrained': False} were passed but not used.

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
Cell In[5], line 1
----> 1 model = LightningInferenceModel.from_config(
      2     config_path=config_path,
      3     checkpoint_path=weights_path,
      4     # predict_dataset_bands=predict_dataset_bands,
      5 )

File ~/miniforge3/envs/prithviEOV2/lib/python3.11/site-packages/terratorch/cli_tools.py:528, in LightningInferenceModel.from_config(config_path, checkpoint_path, predict_dataset_bands, predict_output_bands)
    524 if predict_output_bands is not None:
    525     arguments.extend([ "--data.init_args.predict_output_bands",
    526     "[" + ",".join(predict_output_bands) + "]",])
--> 528 cli = build_lightning_cli(arguments, run=False)
    529 trainer = cli.trainer
    530 # disable logging metrics

File ~/miniforge3/envs/prithviEOV2/lib/python3.11/site-packages/terratorch/cli_tools.py:447, in build_lightning_cli(args, run)
    439     else:
    440         warnings.warn(
    441             "Found terratorch_FLOAT_32_PRECISION env variable but value was set to precision.\
    442             Set to one of {allowed_values}. Will be ignored this run.",
    443             UserWarning,
    444             stacklevel=1,
    445         )
--> 447 return MyLightningCLI(
    448     model_class=BaseTask,
    449     subclass_mode_model=True,
    450     subclass_mode_data=True,
    451     seed_everything_default=0,
    452     save_config_callback=StudioDeploySaveConfigCallback if run else None,
    453     save_config_kwargs={"overwrite": True},
    454     args=args,
    455     # save only state_dict as well as full state. Only state_dict will be used for exporting the model
    456     trainer_defaults={"callbacks": [CustomWriter(write_interval="batch")]},
    457     run=run,
    458     trainer_class=MyTrainer,
    459 )

File ~/miniforge3/envs/prithviEOV2/lib/python3.11/site-packages/lightning/pytorch/cli.py:391, in LightningCLI.__init__(self, model_class, datamodule_class, save_config_callback, save_config_kwargs, trainer_class, trainer_defaults, seed_everything_default, parser_kwargs, subclass_mode_model, subclass_mode_data, args, run, auto_configure_optimizers)
    389 self._add_instantiators()
    390 self.before_instantiate_classes()
--> 391 self.instantiate_classes()
    393 if self.subcommand is not None:
    394     self._run_subcommand(self.subcommand)

File ~/miniforge3/envs/prithviEOV2/lib/python3.11/site-packages/terratorch/cli_tools.py:384, in MyLightningCLI.instantiate_classes(self)
    382 def instantiate_classes(self) -> None:
--> 384     super().instantiate_classes()
    385     # get the predict_output_dir. Depending on the value of run, it may be in the subcommand
    386     try:

File ~/miniforge3/envs/prithviEOV2/lib/python3.11/site-packages/lightning/pytorch/cli.py:557, in LightningCLI.instantiate_classes(self)
    555 def instantiate_classes(self) -> None:
    556     """Instantiates the classes and sets their attributes."""
--> 557     self.config_init = self.parser.instantiate_classes(self.config)
    558     self.datamodule = self._get(self.config_init, "data")
    559     self.model = self._get(self.config_init, "model")

File ~/miniforge3/envs/prithviEOV2/lib/python3.11/site-packages/jsonargparse/_deprecated.py:140, in parse_as_dict_patch.<locals>.patched_instantiate_classes(self, cfg, **kwargs)
    138 if isinstance(cfg, dict):
    139     cfg = self._apply_actions(cfg)
--> 140 cfg = self._unpatched_instantiate_classes(cfg, **kwargs)
    141 return cfg.as_dict() if self._parse_as_dict else cfg

File ~/miniforge3/envs/prithviEOV2/lib/python3.11/site-packages/jsonargparse/_core.py:1204, in ArgumentParser.instantiate_classes(self, cfg, instantiate_groups)
   1198         if value is not None:
   1199             with parser_context(
   1200                 parent_parser=self,
   1201                 nested_links=ActionLink.get_nested_links(self, component),
   1202                 class_instantiators=self._get_instantiators(),
   1203             ):
-> 1204                 parent[key] = component.instantiate_classes(value)
   1205 else:
   1206     with parser_context(load_value_mode=self.parser_mode, class_instantiators=self._get_instantiators()):

File ~/miniforge3/envs/prithviEOV2/lib/python3.11/site-packages/jsonargparse/_typehints.py:611, in ActionTypeHint.instantiate_classes(self, value)
    609 sub_add_kwargs = getattr(self, "sub_add_kwargs", {})
    610 for num, val in enumerate(value):
--> 611     value[num] = adapt_typehints(
    612         val,
    613         self._typehint,
    614         default=self.default,
    615         instantiate_classes=True,
    616         sub_add_kwargs=sub_add_kwargs,
    617         logger=self.logger,
    618     )
    619 return value if islist else value[0]

File ~/miniforge3/envs/prithviEOV2/lib/python3.11/site-packages/jsonargparse/_typehints.py:826, in adapt_typehints(val, typehint, serialize, instantiate_classes, prev_val, append, list_item, enable_path, sub_add_kwargs, default, logger)
    824             vals.append(ex)
    825     if all(isinstance(v, Exception) for v in vals):
--> 826         raise_union_unexpected_value(typehint, val, vals)
    827     val = [v for v in vals if not isinstance(v, Exception)][0]
    829 # Tuple or Set

File ~/miniforge3/envs/prithviEOV2/lib/python3.11/site-packages/jsonargparse/_typehints.py:710, in raise_union_unexpected_value(uniontype, val, exceptions)
    708 errors = errors.replace(f". Got value: {val}", "").replace(f" {val} ", " ")
    709 subtypes = uniontype.__args__
--> 710 raise ValueError(
    711     f"Does not validate against any of the Union subtypes\nSubtypes: {subtypes}"
    712     f"\nErrors:\n{errors}\nGiven value type: {type(val)}\nGiven value: {val}"
    713 ) from exceptions[0]

ValueError: Does not validate against any of the Union subtypes
Subtypes: (<class 'torchgeo.trainers.base.BaseTask'>, <class 'NoneType'>)
Errors:
  - arguments {'pretrained': False} were passed but not used.
  - Expected a <class 'NoneType'>
Given value type: <class 'jsonargparse._namespace.Namespace'>
Given value: Namespace(class_path='terratorch.tasks.SemanticSegmentationTask', init_args=Namespace(model_args={'backbone': 'prithvi_eo_v2_300', 'backbone_bands': ['BLUE', 'GREEN', 'RED', 'NIR_NARROW', 'SWIR_1', 'SWIR_2'], 'backbone_pretrained': True, 'decoder': 'UperNetDecoder', 'decoder_channels': 256, 'decoder_scale_modules': True, 'head_dropout': 0.1, 'necks': [{'indices': [5, 11, 17, 23], 'name': 'SelectIndices'}, {'name': 'ReshapeTokensToImage'}], 'num_classes': 2, 'pretrained': False, 'rescale': True}, model_factory='EncoderDecoderFactory', model=None, loss='ce', aux_heads=None, aux_loss=None, class_weights=None, ignore_index=-1, lr=0.001, optimizer=None, optimizer_hparams=None, scheduler=None, scheduler_hparams=None, freeze_backbone=False, freeze_decoder=False, freeze_head=False, plot_on_val=10, class_names=None, tiled_inference_parameters=None, test_dataloaders_names=None, lr_overrides=None))

Copy link
Collaborator

@romeokienzler romeokienzler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@WanjiruCate lgtm, thanks a lot

  • merged

@romeokienzler romeokienzler merged commit 1fa07cd into IBM:main Feb 12, 2025
6 of 7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants