
optimizer clean up #4658

Merged 136 commits on Dec 1, 2020
Commits (136)
943e5f1
add LightningOptimizer
tchaton Nov 13, 2020
c778036
typo
tchaton Nov 13, 2020
5d00c20
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 13, 2020
e4c5a2e
add mock closure
tchaton Nov 13, 2020
9681634
Merge branch 'bugfix/4572_optimizer_step' of https://github.com/PyTor…
tchaton Nov 13, 2020
397e2b4
typo
tchaton Nov 13, 2020
3ac8ed1
remove logic in optimizer_step
tchaton Nov 13, 2020
9cf88bd
update
tchaton Nov 13, 2020
d952917
update
tchaton Nov 13, 2020
698034a
update
tchaton Nov 13, 2020
bbe743a
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 13, 2020
50af33d
desactivate LightningOptimizer for hovorod
tchaton Nov 13, 2020
99913b6
Merge branch 'bugfix/4572_optimizer_step' of https://github.com/PyTor…
tchaton Nov 13, 2020
ad2a99a
resolve flake
tchaton Nov 13, 2020
59dec77
typo
tchaton Nov 13, 2020
49f3b2e
check optimizer name
tchaton Nov 13, 2020
3034eaa
change name
tchaton Nov 13, 2020
98a62c4
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 15, 2020
3bf605d
added backward to LightningOptimizer
tchaton Nov 16, 2020
9d13ca8
remove use_lightning_optimizer
tchaton Nov 16, 2020
11868e6
move update
tchaton Nov 16, 2020
083dffd
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 16, 2020
3aaeb81
simplify init
tchaton Nov 16, 2020
81bbfee
Merge branch 'bugfix/4572_optimizer_step' of https://github.com/PyTor…
tchaton Nov 16, 2020
6dc1d00
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 16, 2020
28b2a8f
resolve comments
tchaton Nov 16, 2020
91545da
Merge branch 'bugfix/4572_optimizer_step' of https://github.com/PyTor…
tchaton Nov 16, 2020
52436e4
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 18, 2020
1a05a9a
resolve bug
tchaton Nov 18, 2020
424849b
Merge branch 'bugfix/4572_optimizer_step' of https://github.com/PyTor…
tchaton Nov 18, 2020
2cedc5c
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 18, 2020
5f9c6b0
update
tchaton Nov 18, 2020
c805655
Merge branch 'bugfix/4572_optimizer_step' of https://github.com/PyTor…
tchaton Nov 18, 2020
d34b417
update
tchaton Nov 18, 2020
e62ada1
resolve bugs
tchaton Nov 18, 2020
b4ad74f
resolve flake8
tchaton Nov 18, 2020
c81e179
set state
tchaton Nov 18, 2020
8b322c1
work manual_optimizer_step
tchaton Nov 18, 2020
c593d08
add doc
tchaton Nov 18, 2020
027ba6f
add enable_pl_optimizer
tchaton Nov 18, 2020
0219c63
make optimizer_step
tchaton Nov 18, 2020
162b535
add make_optimizer_step
tchaton Nov 18, 2020
86069ff
add examples
tchaton Nov 18, 2020
5dbbccc
resolve test
tchaton Nov 18, 2020
e4f42bd
add test_optimizer_return_options_enable_pl_optimizer
tchaton Nov 18, 2020
e59d660
add enable_pl_optimizer=True
tchaton Nov 18, 2020
c8aebdd
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 18, 2020
82d2032
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 19, 2020
c110002
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 19, 2020
a37ad78
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 20, 2020
a5e8a18
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 20, 2020
d3fcdbb
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 20, 2020
525294b
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 23, 2020
5fbba23
update
tchaton Nov 23, 2020
0d25494
Merge branch 'bugfix/4572_optimizer_step' of https://github.com/PyTor…
tchaton Nov 23, 2020
4f09984
update tests
tchaton Nov 23, 2020
d785936
resolve bugs
tchaton Nov 23, 2020
01591cb
update
tchaton Nov 23, 2020
f9627a8
set Trainer to False
tchaton Nov 23, 2020
f6b0a1f
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 23, 2020
9013985
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 23, 2020
be23cfb
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 24, 2020
d7e2080
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 25, 2020
ec09fc6
update
tchaton Nov 25, 2020
81be876
resolve bugs
tchaton Nov 25, 2020
04b5f18
update
tchaton Nov 25, 2020
b207dc0
remove from doc
tchaton Nov 25, 2020
be2c13f
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 25, 2020
cdc56e6
resolve bug
tchaton Nov 25, 2020
7cb0fe1
Merge branch 'bugfix/4572_optimizer_step' of https://github.com/PyTor…
tchaton Nov 25, 2020
393ef1a
typo
tchaton Nov 25, 2020
4eb11bf
update
tchaton Nov 26, 2020
34f6854
set to True
tchaton Nov 26, 2020
d24e714
simplification
tchaton Nov 26, 2020
e2ea18e
typo
tchaton Nov 26, 2020
c55df45
resolve horovod
tchaton Nov 26, 2020
728ef64
unwrap horovod
tchaton Nov 26, 2020
b0c220a
remove Optimizer
tchaton Nov 26, 2020
49d629b
resolve horovod
tchaton Nov 26, 2020
8bdc9bd
move logic to amp_backend
tchaton Nov 26, 2020
fc5f149
doesn't seem to be pickable
tchaton Nov 26, 2020
1d0bf6c
update
tchaton Nov 26, 2020
f66fc8d
add again
tchaton Nov 26, 2020
0a03907
resolve some bugs
tchaton Nov 26, 2020
f27c046
cleanup
tchaton Nov 26, 2020
b10723f
resolve bug with AMP
tchaton Nov 26, 2020
8e6f63f
change __repr__
tchaton Nov 26, 2020
fe0c876
round at -12
tchaton Nov 26, 2020
0c39f3b
udpate
tchaton Nov 26, 2020
9ce8ec5
update
tchaton Nov 26, 2020
9d9e5f7
update
tchaton Nov 26, 2020
193b4ec
remove from horovod
tchaton Nov 26, 2020
30a352f
typo
tchaton Nov 26, 2020
0fbe742
add convert_to_lightning_optimizers in each accelerators
tchaton Nov 26, 2020
32e4769
typo
tchaton Nov 26, 2020
d9bfe2e
forgot
tchaton Nov 26, 2020
81dbd40
forgot a convert_to_lightning_optimizers
tchaton Nov 26, 2020
9de6597
update
tchaton Nov 26, 2020
b629248
update
tchaton Nov 26, 2020
9db4a69
update
tchaton Nov 26, 2020
a97789b
increase coverage
tchaton Nov 26, 2020
1b63d43
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 26, 2020
2240bd7
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 27, 2020
0f6f50f
update
tchaton Nov 27, 2020
0b3b9d5
Merge branch 'bugfix/4572_optimizer_step' of https://github.com/PyTor…
tchaton Nov 27, 2020
9743193
resolve flake8
tchaton Nov 27, 2020
9527bd4
update
tchaton Nov 27, 2020
58164de
remove useless code
tchaton Nov 27, 2020
8c47f28
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 27, 2020
98116af
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 27, 2020
600d61c
resolve comments + add support for LightningOptimizer base class
tchaton Nov 28, 2020
5d7b7d0
Merge branch 'bugfix/4572_optimizer_step' of https://github.com/PyTor…
tchaton Nov 28, 2020
f96a0d4
resolve flake
tchaton Nov 28, 2020
897a400
check optimizer get wrapped back
tchaton Nov 28, 2020
18e109c
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 28, 2020
bda1f4f
resolve DDPSharded
tchaton Nov 28, 2020
8448b93
reduce code
tchaton Nov 28, 2020
8b22417
lightningoptimizer
tchaton Nov 28, 2020
ae8b6f0
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 29, 2020
9fcf871
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 30, 2020
e3343a7
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 30, 2020
a769a2c
Update pytorch_lightning/core/optimizer.py
tchaton Nov 30, 2020
3fe456b
Update pytorch_lightning/core/lightning.py
tchaton Nov 30, 2020
0b6bb75
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 30, 2020
d599e84
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 30, 2020
54f6ff5
remove reference to step function
tchaton Nov 30, 2020
d8c6817
Apply suggestions from code review
Borda Nov 30, 2020
8224ec6
update on comments
tchaton Nov 30, 2020
dc0610d
Merge branch 'bugfix/4572_optimizer_step' of https://github.com/PyTor…
tchaton Nov 30, 2020
44985b1
resolve
tchaton Nov 30, 2020
09b9ebd
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 30, 2020
abff2ac
Update CHANGELOG.md
williamFalcon Nov 30, 2020
a69881f
add back training_step in apex and native_amp
tchaton Nov 30, 2020
b231fd8
rename optimizer_step
tchaton Nov 30, 2020
c4023dc
Merge branch 'master' into bugfix/4572_optimizer_step
tchaton Nov 30, 2020
53b6921
Merge branch 'master' into bugfix/4572_optimizer_step
SeanNaren Nov 30, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -44,6 +44,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added printing of total num of params, trainable and non-trainable params in ModelSummary ([#4521](https://github.com/PyTorchLightning/pytorch-lightning/pull/4521))

- Added `LightningOptimizer` ([#4658](https://github.com/PyTorchLightning/pytorch-lightning/pull/4658))


### Changed

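For readers skimming the diff: the new `LightningOptimizer` wraps each user-configured `torch.optim.Optimizer` so that calls such as `optimizer.step()` can pass through Lightning's training logic (several commits above touch AMP and the accelerators) without altering the optimizer's own state. The sketch below is an assumption-level illustration of the delegation idea only, not the real class in `pytorch_lightning/core/optimizer.py`.

# Minimal sketch only (assumed, simplified): delegate everything to the wrapped
# optimizer and keep step() as the hook point for trainer-side logic.
import torch

class OptimizerWrapperSketch:
    def __init__(self, optimizer: torch.optim.Optimizer):
        self._optimizer = optimizer

    def __getattr__(self, name):
        # zero_grad, state_dict, param_groups, ... fall through unchanged
        return getattr(self._optimizer, name)

    def step(self, closure=None, **kwargs):
        # the real LightningOptimizer would involve the trainer / precision
        # backend here before delegating
        if closure is None:
            return self._optimizer.step(**kwargs)
        return self._optimizer.step(closure=closure, **kwargs)

    def __repr__(self):
        return f"OptimizerWrapperSketch({self._optimizer.__class__.__name__})"

# usage sketch
params = [torch.nn.Parameter(torch.randn(2, 2))]
opt = OptimizerWrapperSketch(torch.optim.SGD(params, lr=0.1))
(params[0] ** 2).sum().backward()
opt.step()
opt.zero_grad()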
5 changes: 0 additions & 5 deletions docs/source/lightning_module.rst
@@ -1009,11 +1009,6 @@ manual_backward
.. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_backward
:noindex:

manual_optimizer_step
~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_optimizer_step
:noindex:

on_after_backward
~~~~~~~~~~~~~~~~~
4 changes: 2 additions & 2 deletions docs/source/optimizers.rst
@@ -36,7 +36,7 @@ to manually manage the optimization process. To do so, do the following:

# use self.backward which will also handle scaling the loss when using amp
self.manual_backward(loss_a, opt_g)
self.manual_optimizer_step(opt_g)
opt_g.step()


# do anything you want
@@ -45,7 +45,7 @@ to manually manage the optimization process. To do so, do the following:
# pass in any args that loss.backward() normally takes
self.manual_backward(loss_b, opt_d, retain_graph=True)
self.manual_backward(loss_b, opt_d)
self.manual_optimizer_step(opt_d)
opt_d.step()

# log losses
self.log('loss_a', loss_a)
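To make the doc change above concrete: after this PR, manual optimization calls `opt.step()` directly on the (now wrapped) optimizer instead of `self.manual_optimizer_step(opt)`. Below is a hedged sketch of a full `training_step` in that style, assuming the optimizers come from `self.optimizers()` as on the full docs page; `generator_loss` and `discriminator_loss` are hypothetical helpers, not part of the PR.

# Sketch only, mirroring the updated snippet above; not copied from the docs.
def training_step(self, batch, batch_idx, optimizer_idx=None):
    (opt_g, opt_d) = self.optimizers()

    # generator update
    loss_a = self.generator_loss(batch)
    self.manual_backward(loss_a, opt_g)  # also handles AMP loss scaling
    opt_g.step()                         # was: self.manual_optimizer_step(opt_g)
    opt_g.zero_grad()

    # discriminator update
    loss_b = self.discriminator_loss(batch)
    self.manual_backward(loss_b, opt_d)
    opt_d.step()                         # was: self.manual_optimizer_step(opt_d)
    opt_d.zero_grad()

    # log losses
    self.log('loss_a', loss_a)
    self.log('loss_b', loss_b)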
15 changes: 7 additions & 8 deletions pytorch_lightning/accelerators/accelerator_connector.py
@@ -11,18 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning import accelerators
import os

import torch

from pytorch_lightning.utilities import device_parser, XLA_AVAILABLE
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning import _logger as log
from pytorch_lightning import accelerators
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.cluster_environments.slurm_environment import SLURMEnvironment
from pytorch_lightning.cluster_environments.torchelastic_environment import TorchElasticEnvironment
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities import XLA_AVAILABLE, device_parser, rank_zero_only
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
import horovod.torch as hvd
@@ -397,8 +397,7 @@ def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):
def determine_local_rank(self):
if self.trainer.is_slurm_managing_tasks:
return int(os.environ['SLURM_LOCALID'])
else:
return int(os.environ.get('LOCAL_RANK', 0))
return int(os.environ.get('LOCAL_RANK', 0))

def determine_ddp_node_rank(self):
if self.trainer.is_slurm_managing_tasks:
4 changes: 3 additions & 1 deletion pytorch_lightning/accelerators/cpu_accelerator.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union, Any
from typing import Any, Optional, Union

import torch

@@ -47,6 +47,8 @@ def setup(self, model):
# allow for lr schedulers as well
self.setup_optimizers(model)

self.trainer.convert_to_lightning_optimizers()

self.trainer.model = model

def train(self):
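The `self.trainer.convert_to_lightning_optimizers()` call added here recurs in every accelerator's setup path below, always after the optimizers and precision backend are configured. A hedged guess at the shape of that helper follows; the real Trainer method may differ, and `wrap()` is only a placeholder for constructing a `LightningOptimizer`.

# Assumed sketch of an idempotent conversion over trainer.optimizers.
import torch

def wrap(optimizer: torch.optim.Optimizer):
    # placeholder for LightningOptimizer(optimizer); mark it so a second
    # setup pass does not wrap twice
    setattr(optimizer, "_pl_wrapped", True)
    return optimizer

def convert_to_lightning_optimizers(trainer) -> None:
    # rewrap every configured optimizer exactly once, preserving order
    trainer.optimizers = [
        opt if getattr(opt, "_pl_wrapped", False) else wrap(opt)
        for opt in trainer.optimizers
    ]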
14 changes: 8 additions & 6 deletions pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -12,23 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License
import os
from typing import Any, List, Optional, Union

import torch
import torch.distributed as torch_distrib
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available
from torch.nn.parallel import DistributedDataParallel
from typing import List, Optional, Union, Any

if HYDRA_AVAILABLE:
from hydra.utils import to_absolute_path, get_original_cwd
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path


class DDP2Accelerator(Accelerator):
@@ -170,6 +170,8 @@ def ddp_train(self, process_idx, mp_queue, model):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# device ids change depending on the DDP setup
device_ids = self.get_device_ids()

19 changes: 9 additions & 10 deletions pytorch_lightning/accelerators/ddp_accelerator.py
@@ -12,32 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License
import os
import torch
import torch.distributed as torch_distrib
import subprocess
import sys
from os.path import abspath
from time import sleep
from typing import Any, Optional, List, Union
from typing import Any, List, Optional, Union

import numpy as np
import torch
import torch.distributed as torch_distrib
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE
from pytorch_lightning.utilities.distributed import find_free_network_port
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.distributed import sync_ddp_if_available
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import find_free_network_port, rank_zero_only, sync_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything
from torch.nn.parallel import DistributedDataParallel


if HYDRA_AVAILABLE:
from hydra.utils import to_absolute_path, get_original_cwd
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path


class DDPAccelerator(Accelerator):
@@ -266,6 +266,8 @@ def ddp_train(self, process_idx, model):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# device ids change depending on the DDP setup
device_ids = self.get_device_ids()

12 changes: 9 additions & 3 deletions pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
@@ -23,10 +23,14 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.distributed import find_free_network_port, sync_ddp_if_available
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import (
find_free_network_port,
rank_zero_only,
rank_zero_warn,
sync_ddp_if_available,
)

if HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
@@ -130,6 +134,8 @@ def ddp_train(self, process_idx, mp_queue, model):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# DDP spawn already spawned off each process... no need to do anything
device_ids = self.get_device_ids()

7 changes: 4 additions & 3 deletions pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -23,13 +23,12 @@
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available


if HYDRA_AVAILABLE:
from hydra.utils import to_absolute_path, get_original_cwd
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path


class DDPHPCAccelerator(Accelerator):
@@ -164,6 +163,8 @@ def ddp_train(self, process_idx, model):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# device ids change depending on the DDP setup
device_ids = self.get_device_ids()

19 changes: 13 additions & 6 deletions pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -16,20 +16,25 @@
from typing import Any, List, Optional, Union

import torch
import torch.multiprocessing as mp
import torch.distributed as torch_distrib
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE
from pytorch_lightning.utilities.cloud_io import atomic_save, load as pl_load
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, find_free_network_port
from pytorch_lightning.utilities.distributed import sync_ddp_if_available
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.distributed import (
find_free_network_port,
rank_zero_only,
rank_zero_warn,
sync_ddp_if_available,
)
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.distributed.dist import LightningDistributed

if HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
@@ -141,6 +146,8 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# device ids change depending on the DDP setup
device_ids = self.get_device_ids()

6 changes: 4 additions & 2 deletions pytorch_lightning/accelerators/dp_accelerator.py
@@ -16,10 +16,10 @@
import torch
from torch import optim

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.overrides.data_parallel import LightningDataParallel
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -63,6 +63,8 @@ def setup(self, model):
if self.trainer.amp_backend:
model = self.__init_half_precision(model)

self.trainer.convert_to_lightning_optimizers()

self.trainer.model = model

def __init_torch_data_parallel(self, model):
6 changes: 4 additions & 2 deletions pytorch_lightning/accelerators/gpu_accelerator.py
@@ -11,13 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union, Optional, Any
from typing import Any, Optional, Union

import torch

from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.utilities import AMPType


class GPUAccelerator(Accelerator):
@@ -52,6 +52,8 @@ def setup(self, model):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

self.trainer.model = model

def train(self):
2 changes: 2 additions & 0 deletions pytorch_lightning/accelerators/horovod_accelerator.py
@@ -93,6 +93,8 @@ def _filter_named_parameters(model, optimizer):
# 16-bit
model = self.trainer.precision_connector.connect(model)

self.trainer.convert_to_lightning_optimizers()

# Update logger rank info from Horovod to avoid race conditions from different ranks
# creating directories / writing files in the same locations.
self.trainer.global_rank = hvd.rank()
6 changes: 4 additions & 2 deletions pytorch_lightning/accelerators/tpu_accelerator.py
@@ -14,7 +14,7 @@
import io
import os
import re
from typing import Optional, Union, Any
from typing import Any, Optional, Union

import torch
import torch.multiprocessing as mp
@@ -23,7 +23,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn, TPU_AVAILABLE
from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -230,6 +230,8 @@ def __setup_tpu_training(self, model: LightningModule, trainer):
f' global rank: {trainer.tpu_global_core_rank}'
f' with XLA_USE_BF16={os.environ.get("XLA_USE_BF16")}')

self.trainer.convert_to_lightning_optimizers()

def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
# do backward pass
if self.trainer.train_loop.automatic_optimization: