Stepwise LR scheduler #20211

Status: Open. 34 commits to be merged into base branch master.

Commits (34):
- 5dba6f9 Fix DDP strategy registration with override (01AbhiSingh, Jul 23, 2024)
- 3d8b2bf added ddp alias strategy in strategies/ddp.py (01AbhiSingh, Jul 24, 2024)
- 7a55c5c added ddp alias strategy in strategies/ddp.py (01AbhiSingh, Jul 24, 2024)
- f4b01e5 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 24, 2024)
- 4424d70 Merge branch 'master' into ddp-strategy-alias (01AbhiSingh, Jul 27, 2024)
- 607363e updated tests (01AbhiSingh, Aug 6, 2024)
- 3099586 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 6, 2024)
- 935a9c1 updated test_registry.py (01AbhiSingh, Aug 6, 2024)
- c70ef61 Merge branch 'master' into ddp-strategy-alias (01AbhiSingh, Aug 6, 2024)
- ebfedf6 updated test_cli.py (01AbhiSingh, Aug 6, 2024)
- 3285d7a Merge branch 'ddp-strategy-alias' of https://github.com/01AbhiSingh/p… (01AbhiSingh, Aug 6, 2024)
- 4b7b719 Stepwise LR scheduler not working across epochs (01AbhiSingh, Aug 16, 2024)
- 5be642f Merge remote-tracking branch 'origin' into Stepwise-LR-scheduler (01AbhiSingh, Aug 16, 2024)
- fc01630 Merge branch 'master' into stepwiseLRscheduler (01AbhiSingh, Aug 21, 2024)
- 7f748cf Merge branch 'master' into stepwiseLRscheduler (01AbhiSingh, Aug 28, 2024)
- 06f0a0a Merge branch 'master' into stepwiseLRscheduler (01AbhiSingh, Sep 9, 2024)
- 3c48c9e Merge branch 'master' into stepwiseLRscheduler (Borda, Sep 27, 2024)
- 63cd1f0 Added test for LR scheduler stepping across epoch boundaries (01AbhiSingh, Dec 7, 2024)
- 48a7c8e [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 7, 2024)
- 64ed819 added the required changes (01AbhiSingh, Dec 11, 2024)
- 29af194 added the required changes (01AbhiSingh, Dec 11, 2024)
- 09bc52b added the required changes (01AbhiSingh, Dec 11, 2024)
- e96c474 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 11, 2024)
- 2391336 added the required changes (01AbhiSingh, Dec 12, 2024)
- a273722 Merge branch 'stepwiseLRscheduler' of https://github.com/01AbhiSingh/… (01AbhiSingh, Dec 12, 2024)
- eb98dce [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 12, 2024)
- e45a8f9 added the dataloader function and added the following lib from torch.… (01AbhiSingh, Dec 12, 2024)
- 7adad14 Merge branch 'stepwiseLRscheduler' of https://github.com/01AbhiSingh/… (01AbhiSingh, Dec 12, 2024)
- 7bb9697 Merge branch 'master' into stepwiseLRscheduler (01AbhiSingh, Dec 12, 2024)
- 4c77cb3 Merge branch 'master' into stepwiseLRscheduler (01AbhiSingh, Dec 12, 2024)
- 15052fb [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 12, 2024)
- e30a504 added the changes (01AbhiSingh, Dec 13, 2024)
- 9dbbc8d Merge branch 'stepwiseLRscheduler' of https://github.com/01AbhiSingh/… (01AbhiSingh, Dec 13, 2024)
- 27047bf [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 13, 2024)
src/lightning/pytorch/loops/training_epoch_loop.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -443,7 +443,7 @@ def _update_learning_rates(self, interval: str, update_plateau_schedulers: bool)
             if update_plateau_schedulers ^ config.reduce_on_plateau:
                 continue

-            current_idx = self.batch_idx if interval == "step" else trainer.current_epoch
+            current_idx = self.total_batch_idx if interval == "step" else trainer.current_epoch
             current_idx += 1  # account for both batch and epoch starting from 0
             # Take step if call to update_learning_rates matches the interval key and
             # the current step modulo the scheduler's frequency is zero
```
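The substance of the PR is this one-line change. `self.batch_idx` restarts from zero at every epoch, so with `interval="step"` the modulo check is evaluated against the position within the epoch rather than the run; whenever `frequency` does not divide the epoch length, scheduler steps land on the wrong global iterations. `self.total_batch_idx` increases monotonically across epoch boundaries. A minimal plain-Python sketch of the difference, using the same assumed numbers as the new test below (7 batches per epoch, 3 epochs, frequency 5):

```python
# Minimal sketch: per-epoch counter (old behavior) vs. global counter (the fix).
# Assumed numbers mirror the new test: 7 batches/epoch, 3 epochs, frequency 5.
batches_per_epoch, epochs, frequency = 7, 3, 5

per_epoch_ticks, global_ticks = [], []
total_batch_idx = 0
for epoch in range(epochs):
    for batch_idx in range(batches_per_epoch):
        # Old: `batch_idx` resets each epoch, so the modulo check fires
        # relative to the epoch start.
        if (batch_idx + 1) % frequency == 0:
            per_epoch_ticks.append((epoch, batch_idx))
        # New: `total_batch_idx` never resets, so the check tracks the
        # global optimizer-step count.
        if (total_batch_idx + 1) % frequency == 0:
            global_ticks.append((epoch, batch_idx))
        total_batch_idx += 1

print(per_epoch_ticks)  # [(0, 4), (1, 4), (2, 4)]          -> spaced 7 steps apart
print(global_ticks)     # [(0, 4), (1, 2), (2, 0), (2, 5)]  -> every 5th step
```

With the per-epoch counter the scheduler ticks once per epoch at batch 4 (three ticks, 7 steps apart); with the global counter it ticks every 5 optimizer steps (four ticks), which is what `frequency: 5` promises.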
tests/tests_pytorch/trainer/optimization/test_optimizers.py (55 changes: 53 additions & 2 deletions)

```diff
@@ -12,11 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from unittest import mock
-from unittest.mock import call
+from unittest.mock import call, patch

 import pytest
 import torch
-from lightning.pytorch import Trainer
+from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.core.optimizer import (
     _configure_optimizers,
@@ -27,6 +27,7 @@
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.types import LRSchedulerConfig
 from torch import optim
+from torch.utils.data import DataLoader, TensorDataset

 from tests_pytorch.helpers.runif import RunIf

@@ -657,3 +658,53 @@ def lr_scheduler_step(*_): ...
     else:
         with pytest.raises(MisconfigurationException, match="CustomScheduler` doesn't follow"):
             _init_optimizers_and_lr_schedulers(model)
```


The remainder of the hunk is the newly added test:

```python
@patch("torch.optim.lr_scheduler.StepLR.step")
def test_lr_scheduler_step_across_epoch_boundaries(mocked_sched, tmp_path):
    class StepAcrossEpochsModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            return {"loss": torch.tensor(0.1, requires_grad=True)}

        def train_dataloader(self):
            x = torch.randn(21, 32)
            y = torch.randn(21, 2)
            return DataLoader(TensorDataset(x, y), batch_size=3)

        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "interval": "step",
                    "frequency": 5,  # scheduler steps every 5 iterations
                },
            }

    model = StepAcrossEpochsModel()

    # Trainer configuration for cross-epoch testing
    trainer = Trainer(
        default_root_dir=tmp_path,
        limit_train_batches=7,  # more than `frequency` iterations per epoch
        max_epochs=3,  # test across multiple epochs
    )

    # Fit the model
    trainer.fit(model)

    # Calculate the total number of steps (iterations) and expected scheduler calls
    total_steps = 7 * 3  # total iterations (7 batches per epoch * 3 epochs)
    expected_steps = (total_steps - 1) // 5  # scheduler steps every 5 iterations

    # Assert that the scheduler was called the expected number of times
    assert mocked_sched.call_count == expected_steps
```
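As a sanity check on the expected count (values taken from the test above): with the fix, the loop fires the scheduler whenever `(total_batch_idx + 1) % frequency == 0`, so over 21 global steps with frequency 5 it fires at batches 4, 9, 14, and 19:

```python
# Sanity check of the test's arithmetic, using the values from the test above.
total_steps, frequency = 7 * 3, 5  # 7 batches/epoch * 3 epochs, frequency 5
ticks = [i for i in range(total_steps) if (i + 1) % frequency == 0]
assert ticks == [4, 9, 14, 19]
assert len(ticks) == (total_steps - 1) // frequency == 4
```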