Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ref: precision plugins 1/n #3504

Merged
merged 2 commits into from
Sep 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions pytorch_lightning/accelerators/ddp2_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,9 @@

import torch

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.accelerators.ddp_base_backend import DDPBase
from pytorch_lightning.plugins.apex import ApexPlugin

try:
from hydra.utils import to_absolute_path, get_original_cwd
Expand All @@ -38,7 +34,6 @@ class DDP2Backend(DDPBase):
def __init__(self, trainer):
super().__init__(trainer)
self.task_idx = None
self.precision_backend = None

def setup(self, model):
self._resolve_task_idx()
Expand Down
9 changes: 1 addition & 8 deletions pytorch_lightning/accelerators/ddp_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@
import numpy as np
import torch

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only, find_free_network_port
from pytorch_lightning.utilities.distributed import find_free_network_port
from pytorch_lightning.accelerators.ddp_base_backend import DDPBase

try:
Expand All @@ -35,11 +33,6 @@
else:
HYDRA_AVAILABLE = True

try:
from apex import amp
except ImportError:
amp = None


class DDPBackend(DDPBase):

Expand Down
6 changes: 1 addition & 5 deletions pytorch_lightning/accelerators/ddp_base_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_only
from pytorch_lightning import _logger as log
from pytorch_lightning.plugins.apex import ApexPlugin

try:
from hydra.utils import to_absolute_path, get_original_cwd
Expand All @@ -37,7 +36,6 @@ class DDPBase(Accelerator):

def __init__(self, trainer):
super().__init__(trainer)
self.precision_backend = None

def training_step(self, args):
if self.trainer.amp_backend == AMPType.NATIVE:
Expand Down Expand Up @@ -151,9 +149,7 @@ def ddp_train_tmp(self, process_idx, mp_queue, model, is_master=False, proc_offs

# AMP -
# run through amp wrapper before going to distributed DP
if self.trainer.amp_backend == AMPType.APEX:
self.precision_backend = ApexPlugin(self.trainer)
model, optimizers = self.precision_backend._init(model)
model, optimizers = self.trainer.precision_connector.connect(model, optimizers)

# device ids change depending on the DDP setup
device_ids = self.get_device_ids()
Expand Down
5 changes: 0 additions & 5 deletions pytorch_lightning/accelerators/ddp_spawn_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,6 @@
from pytorch_lightning.utilities.distributed import find_free_network_port
from pytorch_lightning.accelerators.ddp_base_backend import DDPBase

try:
from apex import amp
except ImportError:
amp = None


class DDPSpawnBackend(DDPBase):

Expand Down
10 changes: 1 addition & 9 deletions pytorch_lightning/accelerators/dp_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,13 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.accelerators.base_backend import Accelerator
from pytorch_lightning.plugins.apex import ApexPlugin

try:
from apex import amp
except ImportError:
amp = None


class DataParallelBackend(Accelerator):

def __init__(self, trainer):
super().__init__(trainer)
self.model_autocast_original_forward = None
self.precision_backend = None

def setup(self, model):
# call setup after the ddp process has connected
Expand Down Expand Up @@ -91,8 +84,7 @@ def __init_nvidia_apex(self, model):
f' See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.'
f' We recommend you switch to ddp if you want to use amp')
else:
self.precision_backend = ApexPlugin(self.trainer)
model, optimizers = self.precision_backend._init(model)
model, optimizers = self.trainer.precision_connector.connect(model, self.trainer.optimizers)

return model

Expand Down
7 changes: 2 additions & 5 deletions pytorch_lightning/accelerators/gpu_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,13 @@
import torch
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.accelerators.base_backend import Accelerator
from pytorch_lightning.plugins.apex import ApexPlugin


class GPUBackend(Accelerator):
amp_backend: AMPType

def __init__(self, trainer):
super().__init__(trainer)
self.precision_backend = None

def setup(self, model):

Expand All @@ -40,9 +38,8 @@ def setup(self, model):
self.trainer.lr_schedulers = lr_schedulers
self.trainer.optimizer_frequencies = optimizer_frequencies

if self.trainer.amp_backend == AMPType.APEX:
self.precision_backend = ApexPlugin(self.trainer)
model, optimizers = self.precision_backend._init(model)
# init precision
model, optimizers = self.trainer.precision_connector.connect(model, optimizers)

self.trainer.model = model

Expand Down
8 changes: 2 additions & 6 deletions pytorch_lightning/accelerators/horovod_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,10 @@
# limitations under the License.
from contextlib import ExitStack
import torch
from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.accelerators.base_backend import Accelerator
from pytorch_lightning.utilities.distributed import rank_zero_only
from torch.optim.lr_scheduler import _LRScheduler
from pytorch_lightning.plugins.apex import ApexPlugin

try:
import horovod.torch as hvd
Expand All @@ -33,7 +31,6 @@ class HorovodBackend(Accelerator):

def __init__(self, trainer):
super().__init__(trainer)
self.precision_backend = None

def setup(self, model):
# call setup after the ddp process has connected
Expand Down Expand Up @@ -83,9 +80,8 @@ def filter_named_parameters(model, optimizer):
for optimizer in self.trainer.optimizers
]

if self.trainer.amp_backend == AMPType.APEX:
self.precision_backend = ApexPlugin(self.trainer)
model, optimizers = self.precision_backend._init(model)
# 16-bit
model, self.trainer.optimizers = self.trainer.precision_connector.connect(model, self.trainer.optimizers)

# Update logger rank info from Horovod to avoid race conditions from different ranks
# creating directories / writing files in the same locations.
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/apex.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ class ApexPlugin:
def __init__(self, trainer):
self.trainer = trainer

def _init(self, model):
model, optimizers = self.configure_apex(model, self.trainer.optimizers, self.trainer.amp_level)
def connect(self, model, optimizers):
model, optimizers = self.configure_apex(model, optimizers, self.trainer.amp_level)
self.trainer.optimizers = optimizers
self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
return model, optimizers
Expand Down
29 changes: 29 additions & 0 deletions pytorch_lightning/plugins/native_amp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


class NativeAMP:

def __init__(self, trainer):
self.trainer = trainer

def connect(self, model, optimizers):
self.trainer.optimizers = optimizers
return model, optimizers

def training_step(self, fx, args):
with torch.cuda.amp.autocast():
output = fx(*args)
return output
13 changes: 13 additions & 0 deletions pytorch_lightning/trainer/connectors/precision_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@
# limitations under the License.
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE, rank_zero_warn, AMPType
from pytorch_lightning.plugins.native_amp import NativeAMP
from pytorch_lightning.plugins.apex import ApexPlugin


class PrecisionConnector:

def __init__(self, trainer):
self.trainer = trainer
self.backend = None

def on_trainer_init(self, precision, amp_level, amp_backend):
# AMP init
Expand Down Expand Up @@ -52,15 +55,25 @@ def _setup_amp_backend(self, amp_type: str):
else:
log.info('Using native 16bit precision.')
self.trainer.amp_backend = AMPType.NATIVE
self.backend = NativeAMP(self.trainer)

if amp_type == 'apex':
if not APEX_AVAILABLE:
rank_zero_warn('You have asked for Apex AMP but you have not installed it yet.'
' Install apex first using this guide: https://github.com/NVIDIA/apex#linux')
else:
log.info('Using APEX 16bit precision.')
self.trainer.amp_backend = AMPType.APEX
self.backend = ApexPlugin(self.trainer)

if not self.trainer.amp_backend:
raise ModuleNotFoundError(
f'You have asked for AMP support {amp_type}, but there is no support on your side yet.'
f' Consider installing torch >= 1.6 or NVIDIA Apex.'
)

def connect(self, model, optimizers):
if self.backend:
model, optimizers = self.backend.connect(model, optimizers)

return model, optimizers