Revert removal of empty-parameters check for configure_optimizers() with FSDP (#18785)
awaelchli authored Oct 12, 2023
1 parent 20ce3ae commit 6f6c07d
Showing 3 changed files with 28 additions and 4 deletions.
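For context, the check this commit restores targets setups where FSDP flattens parameters (PyTorch < 2.0, or `use_orig_params=False`): there the optimizer has to be built from the wrapped model's parameters inside configure_optimizers(). Below is a minimal sketch of that pattern, following the hint in the restored error message; the class and layer names are made up for illustration.

import torch
from lightning.pytorch import LightningModule

class FlatParamCompatibleModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def configure_optimizers(self):
        # With FSDP and `use_orig_params=False`, the flat FSDP parameters live on the
        # wrapped module, so reference `self.trainer.model.parameters()` here rather
        # than `self.parameters()`.
        return torch.optim.Adam(self.trainer.model.parameters(), lr=1e-2)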
13 changes: 12 additions & 1 deletion src/lightning/pytorch/strategies/fsdp.py
@@ -327,7 +327,18 @@ def setup(self, trainer: "pl.Trainer") -> None:
     def setup_optimizers(self, trainer: "pl.Trainer") -> None:
         if self.kwargs.get("use_orig_params"):
             return super().setup_optimizers(trainer)
-        if any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
+
+        invalid_params_error = False
+        try:
+            # In PyTorch < 2.0, or if `use_orig_params=False`, the user needs to access
+            # `self.trainer.model.parameters()` in configure_optimizers()
+            super().setup_optimizers(trainer)
+        except ValueError as ex:
+            if "optimizer got an empty parameter list" not in str(ex):
+                raise
+            invalid_params_error = True
+
+        if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
             # We avoid this limitation in PyTorch >= 2.0 by setting `use_orig_params=True`
             raise ValueError(
                 "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the"
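The try/except restored above translates a lower-level failure: if none of the module's original parameters survive FSDP flattening, the user's optimizer is constructed over an empty parameter list and torch raises a generic ValueError before the `_optimizer_has_flat_params` check can run. A standalone sketch of that underlying PyTorch error (not Lightning-specific):

import torch

try:
    # Constructing any torch optimizer with no parameters raises the exact error
    # string that `setup_optimizers` now matches on.
    torch.optim.Adam([], lr=1e-2)
except ValueError as ex:
    assert "optimizer got an empty parameter list" in str(ex)
    print(f"caught: {ex}")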
18 changes: 15 additions & 3 deletions tests/tests_pytorch/strategies/test_fsdp.py
@@ -359,16 +359,22 @@ def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy, strategy_cfg):


 @RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True)
-def test_invalid_parameters_in_optimizer():
+@pytest.mark.parametrize("use_orig_params", [None, False, True])
+def test_invalid_parameters_in_optimizer(use_orig_params):
+    fsdp_kwargs = {}
+    if _TORCH_GREATER_EQUAL_2_0 and use_orig_params is not None:
+        fsdp_kwargs = {"use_orig_params": use_orig_params}
+
     trainer = Trainer(
-        strategy="fsdp",
+        strategy=FSDPStrategy(**fsdp_kwargs),
         accelerator="cuda",
         devices=1,
         fast_dev_run=1,
     )
 
     error_context = (
         nullcontext()
-        if _TORCH_GREATER_EQUAL_2_0
+        if _TORCH_GREATER_EQUAL_2_0 and (_TORCH_GREATER_EQUAL_2_1 or use_orig_params is not False)
         else pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters")
     )
 
@@ -385,6 +391,12 @@ def configure_optimizers(self):
             layer = torch.nn.Linear(4, 5)
             return torch.optim.Adam(layer.parameters(), lr=1e-2)
 
+    error_context = (
+        nullcontext()
+        if _TORCH_GREATER_EQUAL_2_0 and use_orig_params is not False
+        else pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters")
+    )
+
     model = NoFlatParametersModel()
     with error_context:
         trainer.fit(model)
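As the parametrization above suggests, the stricter check matters mainly when `use_orig_params` is disabled; when it is enabled, `setup_optimizers` returns early and defers to the base implementation. A usage sketch, assuming PyTorch >= 2.0 and a single CUDA device as in the test:

from lightning.pytorch import Trainer
from lightning.pytorch.strategies import FSDPStrategy

# With `use_orig_params=True`, optimizers may reference the module's original
# parameters directly, so the FSDP-specific check is skipped.
trainer = Trainer(
    strategy=FSDPStrategy(use_orig_params=True),
    accelerator="cuda",
    devices=1,
    fast_dev_run=1,
)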
1 change: 1 addition & 0 deletions
@@ -388,6 +388,7 @@ def on_before_optimizer_step(self, optimizer, *_):

 def test_step_with_optimizer_closure(tmpdir):
     """Tests that `step` works with optimizer_closure."""
+    seed_everything(1)
 
     class TestModel(BoringModel):
         _losses = []
