Commit

Merge branch 'release/1.2-dev' into refactor/legacy-accel-plug
Borda authored Jan 25, 2021
2 parents a7fffc3 + 30f31d3 commit bf200bf
Showing 2 changed files with 7 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/source/callbacks.rst
@@ -94,6 +94,8 @@ Lightning has a few built-in callbacks.
:nosignatures:
:template: classtemplate.rst

BackboneLambdaFinetuningCallback
BaseFinetuningCallback
Callback
EarlyStopping
GPUStatsMonitor
5 changes: 5 additions & 0 deletions pytorch_lightning/callbacks/finetuning.py
@@ -78,6 +78,7 @@ def _recursive_freeze(module: Module,
def filter_params(module: Module,
train_bn: bool = True) -> Generator:
"""Yields the trainable parameters of a given module.
Args:
module: A given module
train_bn: If True, leave the BatchNorm layers in training mode
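
For orientation, a minimal sketch of consuming this generator, assuming ``filter_params`` is exposed as a static method on ``BaseFinetuningCallback`` (consistent with the ``@staticmethod`` decorators in this file); the toy backbone is hypothetical:

>>> import torch.nn as nn
>>> from pytorch_lightning.callbacks.finetuning import BaseFinetuningCallback
>>> backbone = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8))  # hypothetical toy module
>>> # With train_bn=True, BatchNorm parameters are included among the yielded trainable parameters
>>> trainable = list(BaseFinetuningCallback.filter_params(backbone, train_bn=True))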
@@ -98,6 +99,7 @@ def filter_params(module: Module,
@staticmethod
def freeze(module: Module, train_bn: bool = True) -> None:
"""Freezes the layers up to index n (if n is not None).
Args:
module: The module to freeze (at least partially)
train_bn: If True, leave the BatchNorm layers in training mode
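
A similarly hedged sketch of calling ``freeze`` (same assumptions as above; per the docstring, ``train_bn=True`` should leave the BatchNorm layers trainable while freezing the rest):

>>> BaseFinetuningCallback.freeze(backbone, train_bn=True)
>>> any(p.requires_grad for p in backbone[0].parameters())  # the Linear layer is now frozen
False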
@@ -148,6 +150,7 @@ class BackboneLambdaFinetuningCallback(BaseFinetuningCallback):
Finetune a backbone model based on a user-defined learning rate schedule.
When the backbone learning rate reaches the current model learning rate
and ``should_align`` is set to True, it will align with it for the rest of the training.
Args:
unfreeze_backbone_at_epoch: Epoch at which the backbone will be unfrozen.
lambda_func: Scheduling function for increasing backbone learning rate.
@@ -165,7 +168,9 @@ class BackboneLambdaFinetuningCallback(BaseFinetuningCallback):
reaches it.
verbose: Display current learning rate for model and backbone
round: Precision for displaying learning rate
Example::
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import BackboneLambdaFinetuningCallback
>>> multiplicative = lambda epoch: 1.5
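
The docstring example is cut off by the diff view; a plausible completion, assuming the constructor takes the ``unfreeze_backbone_at_epoch`` and ``lambda_func`` arguments documented above (the epoch value is illustrative):

>>> backbone_finetuning = BackboneLambdaFinetuningCallback(unfreeze_backbone_at_epoch=10, lambda_func=multiplicative)
>>> trainer = Trainer(callbacks=[backbone_finetuning])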
