From 137c7b9d7443c32e21db8e5e9f70c919141984f3 Mon Sep 17 00:00:00 2001
From: Corwin Joy
Date: Mon, 1 Jul 2024 17:51:58 -0700
Subject: [PATCH 1/2] Disable _optimizer_to_device logic

---
 src/lightning/fabric/utilities/optimizer.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/src/lightning/fabric/utilities/optimizer.py b/src/lightning/fabric/utilities/optimizer.py
index e2605ceca4670..c5a9df39a44b6 100644
--- a/src/lightning/fabric/utilities/optimizer.py
+++ b/src/lightning/fabric/utilities/optimizer.py
@@ -14,11 +14,8 @@
 
 from typing import Iterable
 
-from lightning_utilities.core.apply_func import apply_to_collection
-from torch import Tensor
 from torch.optim import Optimizer
 
-from lightning.fabric.utilities.apply_func import move_data_to_device
 from lightning.fabric.utilities.types import _DEVICE
 
 
@@ -30,5 +27,4 @@ def _optimizers_to_device(optimizers: Iterable[Optimizer], device: _DEVICE) -> N
 
 def _optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
     """Moves the state of a single optimizer to the device."""
-    for p, v in optimizer.state.items():
-        optimizer.state[p] = apply_to_collection(v, Tensor, move_data_to_device, device, allow_frozen=True)
+    pass

From d2589fc817fe0c413e18c71eb52da47b0eed1015 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Wed, 3 Jul 2024 02:07:14 +0200
Subject: [PATCH 2/2] update test

---
 tests/tests_pytorch/trainer/test_trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py
index 802c1a17bc448..f5e90fdabf944 100644
--- a/tests/tests_pytorch/trainer/test_trainer.py
+++ b/tests/tests_pytorch/trainer/test_trainer.py
@@ -1766,7 +1766,7 @@ def current_memory():
     trainer.fit(model)
 
     assert trainer.strategy.model is model
-    assert list(trainer.optimizers[0].state.values())[0]["exp_avg_sq"].device == torch.device("cpu")
+    assert list(trainer.optimizers[0].state.values())[0]["exp_avg_sq"].device == torch.device("cuda", 0)
     assert trainer.callback_metrics["train_loss"].device == torch.device("cpu")
 
     assert current_memory() <= initial