diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 18dc8891144ad..8d085df8f9628 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -5,6 +5,21 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
+## [unreleased] - YYYY-MM-DD
+
+### Added
+
+### Changed
+
+- Merging of hparams when logging now ignores parameter names that begin with underscore `_` ([#20221](https://github.com/Lightning-AI/pytorch-lightning/pull/20221))
+
+### Removed
+
+### Fixed
+
+- Fixed `LightningCLI` failing when both the module and the data module save hyperparameters, due to the conflicting internal `_class_path` parameter ([#20221](https://github.com/Lightning-AI/pytorch-lightning/pull/20221))
+
+
 ## [2.4.0] - 2024-08-06
 
 ### Added
diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py
index c763071af5644..e1752c67d9183 100644
--- a/src/lightning/pytorch/loggers/utilities.py
+++ b/src/lightning/pytorch/loggers/utilities.py
@@ -69,6 +69,9 @@ def _log_hyperparams(trainer: "pl.Trainer") -> None:
         lightning_hparams = pl_module.hparams_initial
         inconsistent_keys = []
         for key in lightning_hparams.keys() & datamodule_hparams.keys():
+            if key == "_class_path":
+                # Skip LightningCLI's internal hparam
+                continue
             lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
             if (
                 type(lm_val) != type(dm_val)
@@ -88,6 +91,10 @@ def _log_hyperparams(trainer: "pl.Trainer") -> None:
     elif datamodule_log_hyperparams:
         hparams_initial = trainer.datamodule.hparams_initial
 
+    # Don't log LightningCLI's internal hparam
+    if hparams_initial is not None:
+        hparams_initial = {k: v for k, v in hparams_initial.items() if k != "_class_path"}
+
     for logger in trainer.loggers:
         if hparams_initial is not None:
             logger.log_hyperparams(hparams_initial)
diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py
index d106a05e80ec9..f60eadbf38898 100644
--- a/tests/tests_pytorch/test_cli.py
+++ b/tests/tests_pytorch/test_cli.py
@@ -973,6 +973,29 @@ def test_lightning_cli_save_hyperparameters_untyped_module(cleandir):
     assert model.kwargs == {"x": 1}
 
 
+class TestDataSaveHparams(BoringDataModule):
+    def __init__(self, batch_size: int = 32, num_workers: int = 4):
+        super().__init__()
+        self.save_hyperparameters()
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+
+
+def test_lightning_cli_save_hyperparameters_merge(cleandir):
+    config = {
+        "model": {
+            "class_path": f"{__name__}.TestModelSaveHparams",
+        },
+        "data": {
+            "class_path": f"{__name__}.TestDataSaveHparams",
+        },
+    }
+    with mock.patch("sys.argv", ["any.py", "fit", f"--config={json.dumps(config)}", "--trainer.max_epochs=1"]):
+        cli = LightningCLI(auto_configure_optimizers=False)
+        assert set(cli.model.hparams) == {"optimizer", "scheduler", "activation", "_instantiator", "_class_path"}
+        assert set(cli.datamodule.hparams) == {"batch_size", "num_workers", "_instantiator", "_class_path"}
+
+
 @pytest.mark.parametrize("fn", [fn.value for fn in TrainerFn])
 def test_lightning_cli_trainer_fn(fn):
     class TestCLI(LightningCLI):
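
For context, a minimal sketch of the failure mode this patch addresses: under `LightningCLI`, both the model's and the datamodule's `save_hyperparameters()` record an internal `_class_path` entry, and because the two class paths always differ, the key-by-key consistency check performed before logging used to flag a conflict. Skipping the internal key in the check and dropping it before logging avoids the false mismatch. This is an illustrative standalone example, not the library's actual implementation; the names `merge_hparams` and `IGNORED_KEYS` are hypothetical.

```python
# Minimal sketch of the hparams merge performed before logging; names and
# structure are illustrative, not the code in loggers/utilities.py.
IGNORED_KEYS = {"_class_path"}  # LightningCLI's internal bookkeeping entry


def merge_hparams(lightning_hparams: dict, datamodule_hparams: dict) -> dict:
    # Only non-internal keys shared by both modules must agree.
    inconsistent = [
        key
        for key in lightning_hparams.keys() & datamodule_hparams.keys()
        if key not in IGNORED_KEYS and lightning_hparams[key] != datamodule_hparams[key]
    ]
    if inconsistent:
        raise RuntimeError(f"Conflicting hyperparameter values for keys: {inconsistent}")
    # Merge and drop the internal keys so they are never logged.
    merged = {**lightning_hparams, **datamodule_hparams}
    return {k: v for k, v in merged.items() if k not in IGNORED_KEYS}


# Both modules carry a different _class_path, which previously raised a spurious conflict:
model_hparams = {"lr": 0.01, "_class_path": "my_pkg.MyModel"}
data_hparams = {"batch_size": 32, "_class_path": "my_pkg.MyDataModule"}
print(merge_hparams(model_hparams, data_hparams))  # {'lr': 0.01, 'batch_size': 32}
```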