optimizer clean up (#4658)
* add LightningOptimizer

* typo

* add mock closure

* typo

* remove logic in optimizer_step

* update

* update

* update

* deactivate LightningOptimizer for horovod

* resolve flake

* typo

* check optimizer name

* change name

* added backward to LightningOptimizer

* remove use_lightning_optimizer

* move update

* simplify init

* resolve comments

* resolve bug

* update

* update

* resolve bugs

* resolve flake8

* set state

* work manual_optimizer_step

* add doc

* add enable_pl_optimizer

* make optimizer_step

* add make_optimizer_step

* add examples

* resolve test

* add test_optimizer_return_options_enable_pl_optimizer

* add enable_pl_optimizer=True

* update

* update tests

* resolve bugs

* update

* set Trainer to False

* update

* resolve bugs

* update

* remove from doc

* resolve bug

* typo

* update

* set to True

* simplification

* typo

* resolve horovod

* unwrap horovod

* remove Optimizer

* resolve horovod

* move logic to amp_backend

* doesn't seem to be picklable

* update

* add again

* resolve some bugs

* cleanup

* resolve bug with AMP

* change __repr__

* round at -12

* update

* update

* update

* remove from horovod

* typo

* add convert_to_lightning_optimizers in each accelerators

* typo

* forgot

* forgot a convert_to_lightning_optimizers

* update

* update

* update

* increase coverage

* update

* resolve flake8

* update

* remove useless code

* resolve comments + add support for LightningOptimizer base class

* resolve flake

* check optimizer get wrapped back

* resolve DDPSharded

* reduce code

* lightningoptimizer

* Update pytorch_lightning/core/optimizer.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Update pytorch_lightning/core/lightning.py

* remove reference to step function

* Apply suggestions from code review

* update on comments

* resolve

* Update CHANGELOG.md

* add back training_step in apex and native_amp

* rename optimizer_step

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: William Falcon <waf2107@columbia.edu>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
5 people authored Dec 1, 2020
1 parent 2fe1eff commit c2e6e68
Showing 36 changed files with 741 additions and 332 deletions.
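For context on the diffs below: the PR wraps each torch.optim.Optimizer returned by configure_optimizers in a LightningOptimizer, so that optimizer.step() can be routed through the trainer (closure handling, precision scaling, accelerator hooks) instead of being called on the raw optimizer. The snippet below is only an editorial sketch of that wrapping idea — it is not the actual pytorch_lightning/core/optimizer.py implementation, and the WrappedOptimizer name and behaviour are illustrative assumptions:

    import torch


    class WrappedOptimizer:
        """Illustrative stand-in for the LightningOptimizer idea: delegate
        attribute access to the wrapped torch optimizer, but let step()
        accept a closure so the training loop stays in control."""

        def __init__(self, optimizer: torch.optim.Optimizer):
            self._optimizer = optimizer

        def __getattr__(self, name):
            # fall through to the underlying optimizer (param_groups, state, ...)
            return getattr(self._optimizer, name)

        def step(self, closure=None, **kwargs):
            # the real LightningOptimizer routes this through the trainer and
            # precision backend; here we simply forward to the torch optimizer
            if closure is None:
                return self._optimizer.step(**kwargs)
            return self._optimizer.step(closure=closure, **kwargs)

        def __repr__(self):
            return f"WrappedOptimizer({self._optimizer.__class__.__name__})"


    if __name__ == "__main__":
        model = torch.nn.Linear(4, 1)
        opt = WrappedOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))

        def closure():
            opt.zero_grad()
            loss = model(torch.randn(8, 4)).pow(2).mean()
            loss.backward()
            return loss

        opt.step(closure=closure)
        print(opt)  # WrappedOptimizer(SGD)
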
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -44,6 +44,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added printing of total num of params, trainable and non-trainable params in ModelSummary ([#4521](https://github.com/PyTorchLightning/pytorch-lightning/pull/4521))

- Added optimizer refactors ([#4658](https://github.com/PyTorchLightning/pytorch-lightning/pull/4658))


### Changed

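The commit message above also mentions an enable_pl_optimizer Trainer argument that gates the new behaviour. A minimal usage sketch, assuming the flag added in this PR (its default value may differ between releases):

    import pytorch_lightning as pl

    # assumption: enable_pl_optimizer=True asks the Trainer to wrap the
    # optimizers returned by configure_optimizers in LightningOptimizer
    trainer = pl.Trainer(max_epochs=1, enable_pl_optimizer=True)
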
5 changes: 0 additions & 5 deletions docs/source/lightning_module.rst
@@ -1009,11 +1009,6 @@ manual_backward
.. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_backward
:noindex:

manual_optimizer_step
~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_optimizer_step
:noindex:

on_after_backward
~~~~~~~~~~~~~~~~~
4 changes: 2 additions & 2 deletions docs/source/optimizers.rst
@@ -36,7 +36,7 @@ to manually manage the optimization process. To do so, do the following:
# use self.backward which will also handle scaling the loss when using amp
self.manual_backward(loss_a, opt_g)
self.manual_optimizer_step(opt_g)
opt_g.step()
# do anything you want
@@ -45,7 +45,7 @@ to manually manage the optimization process. To do so, do the following:
# pass in any args that loss.backward() normally takes
self.manual_backward(loss_b, opt_d, retain_graph=True)
self.manual_backward(loss_b, opt_d)
self.manual_optimizer_step(opt_d)
opt_d.step()
# log losses
self.log('loss_a', loss_a)
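The doc change above replaces self.manual_optimizer_step(opt) with a plain opt.step() call. A condensed sketch of a manual-optimization training_step under the new API — the loss helpers are placeholders, fetching the optimizers via self.optimizers() is an assumption, and the method is meant to live inside a LightningModule with automatic optimization disabled:

    def training_step(self, batch, batch_idx, optimizer_idx):
        # two optimizers (e.g. generator / discriminator), stepped manually
        (opt_g, opt_d) = self.optimizers()

        loss_a = self.generator_loss(batch)          # placeholder helper
        # manual_backward still handles AMP loss scaling for you
        self.manual_backward(loss_a, opt_g)
        opt_g.step()
        opt_g.zero_grad()

        loss_b = self.discriminator_loss(batch)      # placeholder helper
        self.manual_backward(loss_b, opt_d, retain_graph=True)
        self.manual_backward(loss_b, opt_d)
        opt_d.step()
        opt_d.zero_grad()

        self.log('loss_a', loss_a)
        self.log('loss_b', loss_b)
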
15 changes: 7 additions & 8 deletions pytorch_lightning/accelerators/accelerator_connector.py
@@ -11,18 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning import accelerators
import os

import torch

from pytorch_lightning.utilities import device_parser, XLA_AVAILABLE
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning import _logger as log
from pytorch_lightning import accelerators
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.cluster_environments.slurm_environment import SLURMEnvironment
from pytorch_lightning.cluster_environments.torchelastic_environment import TorchElasticEnvironment
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities import XLA_AVAILABLE, device_parser, rank_zero_only
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
import horovod.torch as hvd
@@ -397,8 +397,7 @@ def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):
def determine_local_rank(self):
if self.trainer.is_slurm_managing_tasks:
return int(os.environ['SLURM_LOCALID'])
else:
return int(os.environ.get('LOCAL_RANK', 0))
return int(os.environ.get('LOCAL_RANK', 0))

def determine_ddp_node_rank(self):
if self.trainer.is_slurm_managing_tasks:
4 changes: 3 additions & 1 deletion pytorch_lightning/accelerators/cpu_accelerator.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union, Any
from typing import Any, Optional, Union

import torch

@@ -47,6 +47,8 @@ def setup(self, model):
# allow for lr schedulers as well
self.setup_optimizers(model)

self.trainer.convert_to_lightning_optimizers()

self.trainer.model = model

def train(self):
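Each accelerator's setup now calls self.trainer.convert_to_lightning_optimizers() right after the optimizers are created; the same one-line addition repeats in the DDP, DP, GPU, Horovod and TPU accelerators below. A rough sketch of what such a conversion step plausibly does, reusing the WrappedOptimizer sketch above — an assumption about the helper, not its actual implementation:

    from typing import List

    import torch


    def convert_to_wrapped_optimizers(
        optimizers: List[torch.optim.Optimizer],
    ) -> list:
        """Wrap each raw torch optimizer once, leaving already-wrapped ones
        untouched; the trainer would store the returned list back on itself."""
        converted = []
        for opt in optimizers:
            if isinstance(opt, WrappedOptimizer):  # WrappedOptimizer: see sketch above
                converted.append(opt)
            else:
                converted.append(WrappedOptimizer(opt))
        return converted
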
14 changes: 8 additions & 6 deletions pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -12,23 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License
import os
from typing import Any, List, Optional, Union

import torch
import torch.distributed as torch_distrib
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available
from torch.nn.parallel import DistributedDataParallel
from typing import List, Optional, Union, Any

if HYDRA_AVAILABLE:
from hydra.utils import to_absolute_path, get_original_cwd
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path


class DDP2Accelerator(Accelerator):
@@ -170,6 +170,8 @@ def ddp_train(self, process_idx, mp_queue, model):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# device ids change depending on the DDP setup
device_ids = self.get_device_ids()

19 changes: 9 additions & 10 deletions pytorch_lightning/accelerators/ddp_accelerator.py
@@ -12,32 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License
import os
import torch
import torch.distributed as torch_distrib
import subprocess
import sys
from os.path import abspath
from time import sleep
from typing import Any, Optional, List, Union
from typing import Any, List, Optional, Union

import numpy as np
import torch
import torch.distributed as torch_distrib
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE
from pytorch_lightning.utilities.distributed import find_free_network_port
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.distributed import sync_ddp_if_available
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import find_free_network_port, rank_zero_only, sync_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything
from torch.nn.parallel import DistributedDataParallel


if HYDRA_AVAILABLE:
from hydra.utils import to_absolute_path, get_original_cwd
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path


class DDPAccelerator(Accelerator):
@@ -266,6 +263,8 @@ def ddp_train(self, process_idx, model):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# device ids change depending on the DDP setup
device_ids = self.get_device_ids()

12 changes: 9 additions & 3 deletions pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
@@ -23,10 +23,14 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.distributed import find_free_network_port, sync_ddp_if_available
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import (
find_free_network_port,
rank_zero_only,
rank_zero_warn,
sync_ddp_if_available,
)

if HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
@@ -130,6 +134,8 @@ def ddp_train(self, process_idx, mp_queue, model):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# DDP spawn already spawned off each process... no need to do anything
device_ids = self.get_device_ids()

7 changes: 4 additions & 3 deletions pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -23,13 +23,12 @@
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available


if HYDRA_AVAILABLE:
from hydra.utils import to_absolute_path, get_original_cwd
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path


class DDPHPCAccelerator(Accelerator):
@@ -164,6 +163,8 @@ def ddp_train(self, process_idx, model):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# device ids change depending on the DDP setup
device_ids = self.get_device_ids()

19 changes: 13 additions & 6 deletions pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -16,20 +16,25 @@
from typing import Any, List, Optional, Union

import torch
import torch.multiprocessing as mp
import torch.distributed as torch_distrib
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE
from pytorch_lightning.utilities.cloud_io import atomic_save, load as pl_load
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, find_free_network_port
from pytorch_lightning.utilities.distributed import sync_ddp_if_available
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.distributed import (
find_free_network_port,
rank_zero_only,
rank_zero_warn,
sync_ddp_if_available,
)
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.distributed.dist import LightningDistributed

if HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
@@ -141,6 +146,8 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# device ids change depending on the DDP setup
device_ids = self.get_device_ids()

6 changes: 4 additions & 2 deletions pytorch_lightning/accelerators/dp_accelerator.py
@@ -16,10 +16,10 @@
import torch
from torch import optim

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.overrides.data_parallel import LightningDataParallel
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -63,6 +63,8 @@ def setup(self, model):
if self.trainer.amp_backend:
model = self.__init_half_precision(model)

self.trainer.convert_to_lightning_optimizers()

self.trainer.model = model

def __init_torch_data_parallel(self, model):
6 changes: 4 additions & 2 deletions pytorch_lightning/accelerators/gpu_accelerator.py
@@ -11,13 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union, Optional, Any
from typing import Any, Optional, Union

import torch

from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.utilities import AMPType


class GPUAccelerator(Accelerator):
@@ -52,6 +52,8 @@ def setup(self, model):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

self.trainer.model = model

def train(self):
2 changes: 2 additions & 0 deletions pytorch_lightning/accelerators/horovod_accelerator.py
@@ -93,6 +93,8 @@ def _filter_named_parameters(model, optimizer):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# Update logger rank info from Horovod to avoid race conditions from different ranks
# creating directories / writing files in the same locations.
self.trainer.global_rank = hvd.rank()
6 changes: 4 additions & 2 deletions pytorch_lightning/accelerators/tpu_accelerator.py
@@ -14,7 +14,7 @@
import io
import os
import re
from typing import Optional, Union, Any
from typing import Any, Optional, Union

import torch
import torch.multiprocessing as mp
@@ -23,7 +23,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn, TPU_AVAILABLE
from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -230,6 +230,8 @@ def __setup_tpu_training(self, model: LightningModule, trainer):
f' global rank: {trainer.tpu_global_core_rank}'
f' with XLA_USE_BF16={os.environ.get("XLA_USE_BF16")}')

self.trainer.convert_to_lightning_optimizers()

def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
# do backward pass
if self.trainer.train_loop.automatic_optimization:
(Diffs for the remaining changed files are not shown.)
