configure_optimizers with OneCycleLR and Pretrain Freeze/Unfreeze #1120

Closed
0x6b756d6172 opened this issue Mar 11, 2020 · 10 comments
Labels
question Further information is requested won't fix This will not be worked on

Comments

@0x6b756d6172

0x6b756d6172 commented Mar 11, 2020

Hello. Thanks for the work on this framework; it's something I've been looking for, and I am currently transitioning all my own work from fast.ai to pytorch-lightning. I'm stuck on the configure_optimizers step.

For those not familiar, the core workflow of fast.ai goes something like this:

#create model with frozen pretrained resnet backbone and untrained linear head
model = MyResnetBasedModel()
learner = Learner(model, ...)

#train the head
learner.fit_one_cycle(5)

#unfreeze pretrained layers and train whole model
learner.unfreeze()
learner.fit_one_cycle(5)

fast.ai uses its own system for implementing the OneCycleScheduler, and it's not the most transparent system. PyTorch has its own implementation of the one-cycle policy, OneCycleLR, which its documentation illustrates as follows:

data_loader = torch.utils.data.DataLoader(...)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
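For context, here is a minimal self-contained sketch of how that snippet is typically driven in plain PyTorch: the scheduler is stepped once per batch, right after optimizer.step(), so over the whole run it is stepped exactly len(data_loader) * epochs times. The tiny linear model and random data are illustrative assumptions. The final extra step shows what happens when the loop runs longer than the schedule was sized for.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Illustrative stand-ins (assumptions): random data and a one-layer model.
dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 1))
data_loader = DataLoader(dataset, batch_size=8, shuffle=True)  # 8 batches per epoch
model = nn.Linear(10, 1)

epochs = 10
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
# OneCycleLR overrides the optimizer's lr: it starts at max_lr / div_factor (25 by default).
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=epochs
)

lrs = []  # learning rate used for each batch
for epoch in range(epochs):
    for x, y in data_loader:
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
        lrs.append(scheduler.get_last_lr()[0])
        scheduler.step()  # once per batch, not once per epoch

# The schedule was sized for epochs * len(data_loader) steps; one more step fails.
overstep_error = None
try:
    scheduler.step()
except ValueError as err:
    overstep_error = err
```

The recorded learning rate rises from 0.01 / 25 to 0.01 over roughly the first 30% of the batches and then anneals towards zero, which is why the scheduler needs the total step count up front.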

Note that OneCycleLR needs to know the total number of steps (or steps_per_epoch and epochs, from which it computes the total) in order to generate the correct schedule. configure_optimizers does not appear to offer a way to access the values needed to initialize OneCycleLR, as my code below shows:

def configure_optimizers(self):
    optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, self.hparams.lr, ???)  # <--- what goes here?
    return [optimizer], [scheduler]

Additionally, it's unclear how the fast.ai flow of freeze, train, unfreeze, train works with Lightning, since configure_optimizers appears to be called only once internally by the trainer. It may be possible to train frozen, save a checkpoint, load it and unfreeze, but this adds some extra code overhead.

How can I arrange my code to use OneCycleLR with pretrained freezing/unfreezing? Any guidance on how to approach this would be appreciated.

Thanks.
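Outside Lightning, the fast.ai-style flow can be approximated with two independent one-cycle runs: freeze the backbone, build an optimizer and a OneCycleLR over the trainable parameters only, train, then unfreeze and build a fresh optimizer and schedule. This is a minimal plain-PyTorch sketch, not a Lightning or fast.ai API. It assumes a hypothetical model with `backbone` and `head` submodules (a small `nn.Sequential` stands in for the pretrained ResNet); `set_backbone_trainable`, `make_one_cycle`, `run_stage` and the two `max_lr` values are hypothetical choices for illustration.

```python
import torch
from torch import nn


class MyResnetBasedModel(nn.Module):
    """Hypothetical stand-in: a small 'pretrained' backbone plus a linear head."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(10, 16), nn.ReLU())
        self.head = nn.Linear(16, 1)

    def forward(self, x):
        return self.head(self.backbone(x))


def set_backbone_trainable(model, trainable):
    # Freezing here means excluding the backbone from gradient computation.
    for p in model.backbone.parameters():
        p.requires_grad = trainable


def make_one_cycle(model, max_lr, total_steps):
    # A fresh optimizer over the currently trainable parameters, and a fresh cycle.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=max_lr)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=max_lr, total_steps=total_steps)
    return optimizer, scheduler


def run_stage(model, optimizer, scheduler, batches):
    for x, y in batches:
        optimizer.zero_grad()
        nn.functional.mse_loss(model(x), y).backward()
        optimizer.step()
        scheduler.step()  # stepped per batch


def count_params(optimizer):
    return sum(p.numel() for g in optimizer.param_groups for p in g["params"])


torch.manual_seed(0)
batches = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(20)]  # one "epoch"
model = MyResnetBasedModel()
backbone_w = model.backbone[0].weight

# Stage 1: frozen backbone, train the head only (like the first fit_one_cycle).
set_backbone_trainable(model, False)
before = backbone_w.detach().clone()
opt, sched = make_one_cycle(model, max_lr=1e-2, total_steps=len(batches))
run_stage(model, opt, sched, batches)
stage1_params = count_params(opt)
backbone_changed_stage1 = not torch.equal(before, backbone_w)

# Stage 2: unfreeze and train everything with a new, lower-peak cycle.
set_backbone_trainable(model, True)
before = backbone_w.detach().clone()
opt, sched = make_one_cycle(model, max_lr=1e-3, total_steps=len(batches))
run_stage(model, opt, sched, batches)
stage2_params = count_params(opt)
backbone_changed_stage2 = not torch.equal(before, backbone_w)
```

The second stage gets its own schedule because the first OneCycleLR has already used up its `total_steps`; this is also why a single call to configure_optimizers is an awkward fit for the two-stage flow.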

@0x6b756d6172 0x6b756d6172 added the question Further information is requested label Mar 11, 2020
@github-actions
Contributor

Hi! Thanks for your contribution! Great first issue!

@fabiocapsouza

I'm pretty new to pytorch-lightning, but I'm also struggling with this issue when trying to use BERT's LR schedule, which as far as I understand is exactly OneCycleLR, but without freeze/unfreeze.

It seems to me that the freeze/unfreeze can be done with a custom Callback and on_epoch_end.
As for the total number of steps, I think you have to pre-calculate it when instantiating the model or call len(self.model.train_dataloader()) inside self.configure_optimizers.
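Pre-computing the step count only requires knowing how many batches the loader yields per epoch. A minimal sketch, assuming a standard map-style `torch.utils.data.DataLoader` with a fixed `batch_size` on a single device (so `len(loader)` is `ceil(n_samples / batch_size)`, or the floor when `drop_last=True`) and one scheduler step per batch; the helper names and example numbers are hypothetical.

```python
import math


def batches_per_epoch(n_samples, batch_size, drop_last=False):
    # Mirrors len(DataLoader) for a map-style dataset with a fixed batch size.
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)


def one_cycle_total_steps(n_samples, batch_size, epochs, drop_last=False):
    # When OneCycleLR is stepped once per batch: total_steps = steps_per_epoch * epochs.
    return batches_per_epoch(n_samples, batch_size, drop_last) * epochs


# e.g. 1,000 training images, batch size 64, 5 epochs
steps_per_epoch = batches_per_epoch(1000, 64)     # 16 (the last batch has 40 images)
total_steps = one_cycle_total_steps(1000, 64, 5)  # 80
```

Values computed this way could be passed to the model (for example through hparams) before the trainer is built, instead of calling train_dataloader() inside configure_optimizers.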

@0x6b756d6172
Author

0x6b756d6172 commented Mar 12, 2020

Possibly related issues: #1038 #941.

@fabiocapsouza thanks. Taking your input, I'm setting up my code as below, passing in the number of epochs via hparams and steps_per_epoch via len(self.train_dataloader()), which I think should work once everything is in place. Update: this calls train_dataloader(), which, according to the lifecycle in the documentation, is called again after configure_optimizers. This double call seems like it should be avoidable, especially since train_dataloader() could involve heavy computation.

Additionally, OneCycleLR needs to be stepped on every batch, and it appears the default is to step the LR scheduler every epoch rather than every batch. Based on #941, I believe the return value needs to look something like this, but I'm not sure; the documentation isn't clear on this.

def configure_optimizers(self):
    optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=self.hparams.lr,
        steps_per_epoch=len(self.train_dataloader()),  # batches per epoch
        epochs=self.hparams.epochs,
    )
    scheduler = {"scheduler": scheduler, "interval": "step"}  # <--- step every batch???
    return [optimizer], [scheduler]

import pytorch_lightning as pl

hparams = ...  # needs at least .lr and .epochs
hparams.epochs = 3

model = PLModel(hparams)
trainer = pl.Trainer(gpus=1, max_epochs=hparams.epochs)

@SkafteNicki
Member

@wxc-kumar your code is correct. To make Lightning call the scheduler's .step() method after each batch, the "interval" key needs to be set to "step". Maybe the documentation could be clearer on this.
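To see why the interval matters, here is a small plain-PyTorch sketch (an illustration of the scheduler's behaviour, not of Lightning's internals) that records the learning rate when OneCycleLR is stepped once per batch versus once per epoch. The `lr_trace` helper, the dummy model and the numbers are assumptions for illustration.

```python
import torch
from torch import nn


def lr_trace(step_every_batch, epochs=3, batches_per_epoch=100, max_lr=1e-3):
    """Return the learning rate used for every batch under one of two stepping policies."""
    torch.manual_seed(0)
    model = nn.Linear(4, 1)
    x = torch.randn(8, 4)  # dummy batch; only the schedule matters here
    optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=max_lr, steps_per_epoch=batches_per_epoch, epochs=epochs
    )
    lrs = []
    for _ in range(epochs):
        for _ in range(batches_per_epoch):
            optimizer.zero_grad()
            model(x).pow(2).mean().backward()
            optimizer.step()
            lrs.append(scheduler.get_last_lr()[0])
            if step_every_batch:  # analogous to "interval": "step"
                scheduler.step()
        if not step_every_batch:  # analogous to "interval": "epoch"
            scheduler.step()
    return lrs


per_batch = lr_trace(step_every_batch=True)
per_epoch = lr_trace(step_every_batch=False)
```

With per-epoch stepping, a schedule sized for `epochs * steps_per_epoch` steps is only advanced `epochs` times, so the learning rate stays close to its warm-up starting value (`max_lr / 25` by default) and never approaches `max_lr`.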

@stale

stale bot commented May 12, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the won't fix This will not be worked on label May 12, 2020
@stale stale bot closed this as completed May 21, 2020
@RafailFridman

An important point when setting OneCycleLR parameters:
If you set the number of epochs together with the steps-per-epoch parameter, don't forget to take gradient accumulation into account, like this:

steps_per_epoch = (train_loader_len//self.batch_size)//self.trainer.accumulate_grad_batches

@cowwoc
Contributor

cowwoc commented Sep 3, 2021

It doesn't look like the documentation was ever updated to mention that "interval" must be set to "step". It is very hard for new users to figure this out by themselves.

@tchaton
Contributor

tchaton commented Sep 3, 2021

Hey @cowwoc,

Would you mind making a contribution to improve the documentation?
Also, here is the section with multiple examples that set the interval to step: https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html#pytorch_lightning.core.lightning.LightningModule.configure_optimizers

Best,
T.C

@cowwoc
Contributor

cowwoc commented Sep 3, 2021

You are right, my mistake.

@sachinruk

An important moment on setting OneCycleLR parameters: If you set the number of epochs with steps per epoch parameter, don't forget to take into account gradient accumulation like this:

steps_per_epoch = (train_loader_len//self.batch_size)//self.trainer.accumulate_grad_batches

I know this is an old thread, but did you mean to say:

steps_per_epoch = len(train_loader) // self.trainer.accumulate_grad_batches

Not sure how batch size is relevant here since the update step is per batch anyway.
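Following that reasoning, here is a minimal sketch of the adjusted count. It assumes `num_batches` is `len(train_loader)` (batches, not samples) and that an optimizer step, and hence a scheduler step, happens once per `accumulate_grad_batches` batches. The two formulas quoted above presumably agree if `train_loader_len` in the first one counts samples rather than batches. Rounding up is a deliberately conservative assumption: if a leftover partial accumulation window at the end of an epoch also triggers an optimizer step, floor division undercounts and OneCycleLR raises a `ValueError` near the end of training; if it does not, rounding up only leaves the schedule slightly short of its final value. The helper name is hypothetical.

```python
import math


def optimizer_steps_per_epoch(num_batches, accumulate_grad_batches=1):
    # Conservative: counts a trailing partial accumulation window as one more step.
    return math.ceil(num_batches / accumulate_grad_batches)


num_batches = 250  # e.g. len(train_loader)
steps_per_epoch = optimizer_steps_per_epoch(num_batches, accumulate_grad_batches=4)  # 63
floor_version = num_batches // 4  # 62, the floor-division formula quoted above
```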

7 participants