diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index f750c8aff7caf..6a2c51d6fc104 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -112,6 +112,7 @@ def __init__(self, *args, **kwargs):
         self._current_hook_fx_name = None
         self._current_dataloader_idx = None
         self._automatic_optimization: bool = True
+        self.param_grad_dict = {}
 
     def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]:
         if use_pl_optimizer:
@@ -1171,11 +1172,13 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
             optimizer_idx:
         """
         for param in self.parameters():
+            if param not in self.param_grad_dict:
+                self.param_grad_dict[param] = param.requires_grad
             param.requires_grad = False
 
         for group in optimizer.param_groups:
             for param in group['params']:
-                param.requires_grad = True
+                param.requires_grad = self.param_grad_dict[param]
 
     def optimizer_step(
         self,