Skip to content

Commit

Permalink
Merge branch 'master' into bugfix/5405_logging_with_accumulated_gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton authored Jan 8, 2021
2 parents 702df33 + f2e99d6 commit 20b1abf
Show file tree
Hide file tree
Showing 43 changed files with 218 additions and 270 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Changed deprecated `enable_pl_optimizer=True` ([#5244](https://github.com/PyTorchLightning/pytorch-lightning/pull/5244))


### Deprecated

Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,8 @@ with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmpfile:
```python
class LitAutoEncoder(pl.LightningModule):
def training_step(self, batch, batch_idx, opt_idx):
(opt_a, opt_b) = self.optimizers()
# use_pl_optimizer=True (the default) returns your optimizers wrapped in LightningOptimizer; pass use_pl_optimizer=False to get the raw optimizers
(opt_a, opt_b) = self.optimizers(use_pl_optimizer=True)

loss_a = ...
self.manual_backward(loss_a, opt_a)
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/test_sharded_parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ def train_dataloader(self):
class SeedTrainLoaderManualModel(SeedTrainLoaderModel):
def training_step(self, batch, batch_idx, optimizer_idx):
# manual
(opt_a, opt_b) = self.optimizers()
# use_pl_optimizer=True (the default) returns your optimizers wrapped in LightningOptimizer; pass use_pl_optimizer=False to get the raw optimizers
(opt_a, opt_b) = self.optimizers(use_pl_optimizer=True)
loss_1 = self.step(batch)

self.manual_backward(loss_1, opt_a)
Expand Down
3 changes: 2 additions & 1 deletion docs/source/new-project.rst
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,8 @@ Now you own the train loop!
.. code-block:: python
def training_step(self, batch, batch_idx, opt_idx):
(opt_a, opt_b, opt_c) = self.optimizers()
# use_pl_optimizer=True (the default) returns your optimizers wrapped in LightningOptimizer; pass use_pl_optimizer=False to get the raw optimizers
(opt_a, opt_b, opt_c) = self.optimizers(use_pl_optimizer=True)
loss_a = self.generator(batch[0])
Expand Down
31 changes: 24 additions & 7 deletions docs/source/optimizers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,15 @@ to manually manage the optimization process. To do so, do the following:
.. code-block:: python
def training_step(self, batch, batch_idx, optimizer_idx):
# ignore optimizer_idx
(opt_g, opt_d) = self.optimizers()
# 1. ignore optimizer_idx
# 2. `use_pl_optimizer=True` means `opt_g` and `opt_d` will be of type `LightningOptimizer`
# `LightningOptimizer` simply wraps your optimizer and behaves the same way!
# When calling `optimizer.step`, `LightningOptimizer` will handle TPU, AMP, accumulate_grad_batches, etc. for you.
# To access your raw optimizers, pass `use_pl_optimizer=False`, or use `optimizer.optimizer` when using `use_pl_optimizer=True`
# `use_pl_optimizer=True` is the default
(opt_g, opt_d) = self.optimizers(use_pl_optimizer=True)
# do anything you want
loss_a = ...
Expand Down Expand Up @@ -242,19 +249,29 @@ Here we add a learning-rate warm up
# update params
optimizer.step(closure=closure)

The default ``optimizer_step`` is relying on the internal ``LightningOptimizer`` to properly perform a step.
.. note:: The default ``optimizer_step`` relies on the internal ``LightningOptimizer`` to properly perform a step. It handles TPUs, AMP, accumulate_grad_batches, zero_grad, and much more.

.. testcode::

from pytorch_lightning.core.optimizer import LightningOptimizer
# function hook in LightningModule
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
optimizer.step(closure=closure)

.. note:: To access your wrapped Optimizer from ``LightningOptimizer``, do as follows.

.. testcode::

# function hook in LightningModule
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
if not isinstance(optimizer, LightningOptimizer):
# wraps into LightningOptimizer only for running step
optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer)

# `optimizer` is a `LightningOptimizer` wrapping the optimizer.
# To access it, do as follows:
optimizer = optimizer.optimizer

# run step. However, it won't work on TPU, AMP, etc...
optimizer.step(closure=closure)


----------

Using the closure functions for optimization
Expand Down
6 changes: 4 additions & 2 deletions docs/source/trainer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,8 @@ optimizer behavior
Example::
def training_step(self, batch, batch_idx):
opt = self.optimizers()
# use_pl_optimizer=True (the default) returns your optimizer wrapped in LightningOptimizer; pass use_pl_optimizer=False to get the raw optimizer
opt = self.optimizers(use_pl_optimizer=True)
loss = ...
self.manual_backward(loss, opt)
Expand All @@ -350,7 +351,8 @@ In the multi-optimizer case, ignore the optimizer_idx flag and use the optimizer
Example::
def training_step(self, batch, batch_idx, optimizer_idx):
(opt_a, opt_b) = self.optimizers()
# use_pl_optimizer=True (the default) returns your optimizers wrapped in LightningOptimizer; pass use_pl_optimizer=False to get the raw optimizers
(opt_a, opt_b) = self.optimizers(use_pl_optimizer=True)
gen_loss = ...
self.manual_backward(gen_loss, opt_a)
Expand Down
4 changes: 1 addition & 3 deletions pytorch_lightning/accelerators/cpu_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union, Callable
from typing import Any, Callable, Optional, Union

import torch

Expand Down Expand Up @@ -48,8 +48,6 @@ def setup(self, model):
# allow for lr schedulers as well
self.setup_optimizers(model)

self.trainer.convert_to_lightning_optimizers()

self.trainer.model = model

def train(self):
Expand Down
2 changes: 0 additions & 2 deletions pytorch_lightning/accelerators/ddp2_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,6 @@ def ddp_train(self, process_idx, mp_queue, model):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# device ids change depending on the DDP setup
device_ids = self.get_device_ids()

Expand Down
4 changes: 1 addition & 3 deletions pytorch_lightning/accelerators/ddp_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License
import os
from os.path import abspath
import subprocess
import sys
from os.path import abspath
from time import sleep
from typing import Any, List, Optional, Union

Expand Down Expand Up @@ -291,8 +291,6 @@ def ddp_train(self, process_idx, model):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# device ids change depending on the DDP setup
device_ids = self.get_device_ids()

Expand Down
2 changes: 0 additions & 2 deletions pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,6 @@ def ddp_train(self, process_idx, mp_queue, model):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# DDP spawn already spawned off each process... no need to do anything
device_ids = self.get_device_ids()

Expand Down
4 changes: 1 addition & 3 deletions pytorch_lightning/accelerators/ddp_hpc_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from typing import Any, List, Optional, Union

import torch
import torch.distributed as torch_distrib
import torch.distributed as dist
import torch.distributed as torch_distrib
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import _logger as log
Expand Down Expand Up @@ -183,8 +183,6 @@ def ddp_train(self, process_idx, model):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# device ids change depending on the DDP setup
device_ids = self.get_device_ids()

Expand Down
2 changes: 0 additions & 2 deletions pytorch_lightning/accelerators/ddp_spawn_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,6 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# device ids change depending on the DDP setup
device_ids = self.get_device_ids()

Expand Down
2 changes: 0 additions & 2 deletions pytorch_lightning/accelerators/dp_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ def setup(self, model):
if self.trainer.amp_backend:
model = self.__init_half_precision(model)

self.trainer.convert_to_lightning_optimizers()

self.trainer.model = model

def __init_torch_data_parallel(self, model):
Expand Down
2 changes: 0 additions & 2 deletions pytorch_lightning/accelerators/gpu_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ def setup(self, model):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

self.trainer.model = model

def train(self):
Expand Down
6 changes: 2 additions & 4 deletions pytorch_lightning/accelerators/horovod_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import ExitStack
from typing import Any, Optional, Union, Callable
from typing import Any, Callable, Optional, Union

import torch
from torch.optim.lr_scheduler import _LRScheduler

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.utilities import HOROVOD_AVAILABLE, AMPType
from pytorch_lightning.utilities import AMPType, HOROVOD_AVAILABLE
from pytorch_lightning.utilities.distributed import rank_zero_only

if HOROVOD_AVAILABLE:
Expand Down Expand Up @@ -91,8 +91,6 @@ def _filter_named_parameters(model, optimizer):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# Update logger rank info from Horovod to avoid race conditions from different ranks
# creating directories / writing files in the same locations.
self.trainer.global_rank = hvd.rank()
Expand Down
4 changes: 1 addition & 3 deletions pytorch_lightning/accelerators/tpu_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@
from pytorch_lightning.core import LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities import (
TPU_AVAILABLE,
move_data_to_device,
rank_zero_info,
rank_zero_only,
rank_zero_warn,
TPU_AVAILABLE,
)
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand Down Expand Up @@ -230,8 +230,6 @@ def __setup_tpu_training(self, model: LightningModule, trainer):
f' global rank: {trainer.tpu_global_core_rank}'
f' with XLA_USE_BF16={os.environ.get("XLA_USE_BF16")}')

self.trainer.convert_to_lightning_optimizers()

def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
# do backward pass
if self.trainer.train_loop.automatic_optimization:
Expand Down
7 changes: 5 additions & 2 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,11 @@ def __init__(self, *args, **kwargs):
self._current_hook_fx_name = None
self._current_dataloader_idx = None

def optimizers(self):
opts = self.trainer.optimizers
def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]:
if use_pl_optimizer:
opts = list(self.trainer.lightning_optimizers.values())
else:
opts = self.trainer.optimizers

# single optimizer
if isinstance(opts, list) and len(opts) == 1 and isinstance(opts[0], Optimizer):
Expand Down
20 changes: 13 additions & 7 deletions pytorch_lightning/core/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from torch.optim.optimizer import Optimizer

from pytorch_lightning.utilities import TPU_AVAILABLE
from pytorch_lightning.utilities import AMPType, TPU_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if TPU_AVAILABLE:
Expand Down Expand Up @@ -62,6 +62,10 @@ def __init__(self,
self._accumulate_grad_batches = accumulate_grad_batches
self._optimizer_idx = None

@property
def optimizer(self):
return self._optimizer

@property
def defaults(self):
return self._optimizer.defaults
Expand Down Expand Up @@ -102,11 +106,13 @@ def _on_trainer_init(self, trainer):
break

@classmethod
def to_lightning_optimizer(cls, optimizer, trainer):
if isinstance(optimizer, LightningOptimizer):
return optimizer
optimizer = cls(optimizer)
optimizer._on_trainer_init(trainer)
def _to_lightning_optimizer(cls, optimizer, trainer, opt_idx):
# apex overrides the .step function, so the optimizer needs to be wrapped on each step
if trainer.amp_backend == AMPType.APEX:
optimizer = cls(optimizer)
optimizer._on_trainer_init(trainer)
else:
optimizer = trainer.lightning_optimizers[opt_idx]
return optimizer

def _accumulated_batches_reached(self):
Expand Down Expand Up @@ -148,7 +154,7 @@ def __optimizer_step(self, *args, closure: Optional[Callable] = None, profiler_n
**kwargs
)

trainer.train_loop.on_before_zero_grad(self)
trainer.train_loop.on_before_zero_grad(optimizer)

model.optimizer_zero_grad(
trainer.current_epoch,
Expand Down
5 changes: 2 additions & 3 deletions pytorch_lightning/plugins/ddp_sequential_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from typing import Any, List, Optional

import torch
import torch.distributed as torch_distrib
from torch import nn
import torch.distributed as torch_distrib
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import _logger as log
Expand All @@ -27,8 +27,8 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if FAIRSCALE_PIPE_AVAILABLE:
import fairscale.nn.model_parallel as mpu
from fairscale.nn import PipeRPCWrapper
import fairscale.nn.model_parallel as mpu
from fairscale.nn.pipe import balance as pipe_balance
from fairscale.nn.pipe import rpc as rpc_pipe
from fairscale.nn.pipe.pipeline import PipelineStyle
Expand Down Expand Up @@ -380,7 +380,6 @@ def register_optimizers(ctx, model):
model.trainer.optimizers = optimizers
model.trainer.lr_schedulers = lr_schedulers
model.trainer.optimizer_frequencies = optimizer_frequencies
model.trainer.convert_to_lightning_optimizers()


def run_optimizer(ctx, model):
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/native_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
# unscale gradient to allow analyze within `on_after_backward`
if not self.trainer.train_loop.should_accumulate() and automatic_optimization:
if isinstance(optimizer, LightningOptimizer):
self.trainer.scaler.unscale_(optimizer._optimizer)
self.trainer.scaler.unscale_(optimizer.optimizer)
else:
self.trainer.scaler.unscale_(optimizer)

Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/sharded_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _reinit_with_fairscale_oss(self, trainer):
optimizers = trainer.optimizers
for x, optimizer in enumerate(optimizers):
if is_lightning_optimizer(optimizer):
optimizer = optimizer._optimizer
optimizer = optimizer.optimizer
if not isinstance(optimizer, OSS):
optim_class = type(optimizer)
zero_optimizer = OSS(
Expand All @@ -73,7 +73,6 @@ def _reinit_with_fairscale_oss(self, trainer):
)
optimizers[x] = zero_optimizer
del optimizer
trainer.convert_to_lightning_optimizers()

def get_model_from_plugin(
self,
Expand Down
Loading

0 comments on commit 20b1abf

Please sign in to comment.