Skip to content

Commit

Permalink
Refactor Lightning's trainer.model to trainer.lightning_module (#…
Browse files Browse the repository at this point in the history
…2255)

Refactor trainer.model to trainer.lightning_module
  • Loading branch information
samet-akcay committed Aug 20, 2024
1 parent 2bd2842 commit cfd3d8e
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
4 changes: 2 additions & 2 deletions src/anomalib/callbacks/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
Overrides the parent method to allow saving during both the ``FITTING`` and ``VALIDATING`` states, and to allow
saving when the global step and last_global_step_saved are both 0 (only for zero-/few-shot models).
"""
is_zero_or_few_shot = trainer.model.learning_type in [LearningType.ZERO_SHOT, LearningType.FEW_SHOT]
is_zero_or_few_shot = trainer.lightning_module.learning_type in [LearningType.ZERO_SHOT, LearningType.FEW_SHOT]
return (
bool(trainer.fast_dev_run) # disable checkpointing with fast_dev_run
or trainer.state.fn not in [TrainerFn.FITTING, TrainerFn.VALIDATING] # don't save anything during non-fit
Expand All @@ -52,7 +52,7 @@ def _should_save_on_train_epoch_end(self, trainer: Trainer) -> bool:
if self._save_on_train_epoch_end is not None:
return self._save_on_train_epoch_end

if trainer.model.learning_type in [LearningType.ZERO_SHOT, LearningType.FEW_SHOT]:
if trainer.lightning_module.learning_type in [LearningType.ZERO_SHOT, LearningType.FEW_SHOT]:
return False

return super()._should_save_on_train_epoch_end(trainer)
8 changes: 4 additions & 4 deletions src/anomalib/data/base/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,8 @@ def train_transform(self) -> Transform:
"""
if self._train_transform:
return self._train_transform
if getattr(self, "trainer", None) and self.trainer.model and self.trainer.model.transform:
return self.trainer.model.transform
if getattr(self, "trainer", None) and self.trainer.lightning_module and self.trainer.lightning_module.transform:
return self.trainer.lightning_module.transform
if self.image_size:
return Resize(self.image_size, antialias=True)
return None
Expand All @@ -284,8 +284,8 @@ def eval_transform(self) -> Transform:
"""
if self._eval_transform:
return self._eval_transform
if getattr(self, "trainer", None) and self.trainer.model and self.trainer.model.transform:
return self.trainer.model.transform
if getattr(self, "trainer", None) and self.trainer.lightning_module and self.trainer.lightning_module.transform:
return self.trainer.lightning_module.transform
if self.image_size:
return Resize(self.image_size, antialias=True)
return None
Expand Down
2 changes: 1 addition & 1 deletion src/anomalib/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def model(self) -> AnomalyModule:
Returns:
AnomalyModule: Anomaly model.
"""
if not self.trainer.model:
if not self.trainer.lightning_module:
msg = "Trainer does not have a model assigned yet."
raise UnassignedError(msg)
return self.trainer.lightning_module
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/metrics/test_adaptive_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,5 @@ def test_manual_threshold() -> None:
devices=1,
)
engine.fit(model=model, datamodule=datamodule)
assert engine.trainer.model.image_metrics.F1Score.threshold == image_threshold
assert engine.trainer.model.pixel_metrics.F1Score.threshold == pixel_threshold
assert engine.trainer.lightning_module.image_metrics.F1Score.threshold == image_threshold
assert engine.trainer.lightning_module.pixel_metrics.F1Score.threshold == pixel_threshold

0 comments on commit cfd3d8e

Please sign in to comment.