Commit

ref: decouple apex second attempt part 3/n (#4055)
williamFalcon authored Oct 10, 2020
1 parent 7285613 commit 3a6717c
Showing 3 changed files with 41 additions and 26 deletions.
pytorch_lightning/accelerators/base_accelerator.py (33 changes: 7 additions & 26 deletions)
@@ -72,33 +72,14 @@ def process_dataloader(self, dataloader):
         return dataloader
 
     def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
-
-        # scale loss for 16 bit
         if self.trainer.precision == 16:
-            if self.trainer.amp_backend == AMPType.NATIVE:
-                closure_loss = self.trainer.scaler.scale(closure_loss)
-            else:
-                closure_loss = amp.scale_loss(closure_loss, optimizer)
-
-            # enter amp context
-            if self.trainer.amp_backend == AMPType.APEX:
-                self.trainer.dev_debugger.track_event('AMP', str(AMPType.APEX))
-                context = closure_loss
-                closure_loss = closure_loss.__enter__()
-
-        # do backward pass
-        closure_loss.backward(*args, **kwargs)
-
-        # exit amp context
-        if self.trainer.precision == 16 and self.trainer.amp_backend == AMPType.APEX:
-            a, b, c = None, None, None
-            error = context.__exit__(a, b, c)
-            if error:
-                rank_zero_warn(a, b, c)
-                raise Exception('apex unscale error')
-
-        # once backward has been applied, release graph
-        closure_loss = closure_loss.detach()
+            closure_loss = self.trainer.precision_connector.backend.backward(closure_loss, optimizer, *args, **kwargs)
+        else:
+            # do backward pass
+            closure_loss.backward(*args, **kwargs)
+
+            # once backward has been applied, release graph
+            closure_loss = closure_loss.detach()
         return closure_loss
 
     def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
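For context on what the refactor buys: the accelerator no longer inspects the AMP backend itself; it hands the 16-bit case to whatever precision backend the trainer carries and keeps only the plain full-precision path. The following self-contained sketch uses toy class and attribute names (ToyAccelerator, FakeHalfPrecisionBackend), not Lightning's real API, and assumes only that torch is installed; it illustrates the shape of that dispatch.

# Minimal sketch of the dispatch introduced above, with stand-in classes
# that are NOT Lightning's own.
from types import SimpleNamespace

import torch


class FakeHalfPrecisionBackend:
    """Stand-in for a precision plugin; a real one would scale the loss first."""

    def backward(self, closure_loss, optimizer=None):
        closure_loss.backward()
        return closure_loss.detach()


class ToyAccelerator:
    def __init__(self, trainer):
        self.trainer = trainer

    def backward(self, closure_loss, optimizer=None):
        if self.trainer.precision == 16:
            # precision-specific work lives entirely in the backend plugin
            return self.trainer.precision_connector.backend.backward(closure_loss, optimizer)
        # full precision: plain backward, then release the graph
        closure_loss.backward()
        return closure_loss.detach()


# usage: the accelerator's code path is the same regardless of backend
trainer = SimpleNamespace(
    precision=16,
    precision_connector=SimpleNamespace(backend=FakeHalfPrecisionBackend()),
)
w = torch.ones(1, requires_grad=True)
loss = (w * 3.0).sum()
detached = ToyAccelerator(trainer).backward(loss)
print(w.grad, detached.requires_grad)  # tensor([3.]) False

The payoff of this shape is that adding or changing an AMP backend means touching one plugin class rather than the accelerator's control flow.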
pytorch_lightning/plugins/apex.py (24 changes: 24 additions & 0 deletions)
@@ -13,6 +13,8 @@
 # limitations under the License.
 from typing import List, Tuple
 from torch.optim.optimizer import Optimizer
+from pytorch_lightning.utilities.distributed import rank_zero_warn
+from pytorch_lightning.utilities import AMPType
 
 try:
     from apex import amp
@@ -34,6 +36,28 @@ def training_step(self, fx, args):
         output = fx(args)
         return output
 
+    def backward(self, closure_loss, optimizer, *args, **kwargs):
+        closure_loss = amp.scale_loss(closure_loss, optimizer)
+
+        # enter apex context
+        self.trainer.dev_debugger.track_event('AMP', str(AMPType.APEX))
+        context = closure_loss
+        closure_loss = closure_loss.__enter__()
+
+        # do backward pass
+        closure_loss.backward(*args, **kwargs)
+
+        # exit amp context
+        a, b, c = None, None, None
+        error = context.__exit__(a, b, c)
+        if error:
+            rank_zero_warn(a, b, c)
+            raise Exception('apex unscale error')
+
+        # once backward has been applied, release graph
+        closure_loss = closure_loss.detach()
+        return closure_loss
+
     def configure_apex(
         self,
         amp: object,
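Note that the new ApexPlugin.backward drives the amp.scale_loss context manager by hand (explicit __enter__/__exit__) so it can hold onto the scaled loss tensor, detach it, and return it. For reference, the idiomatic with-statement form of the same apex pattern is sketched below; it assumes apex is installed and that the model and optimizer have already gone through amp.initialize, and apex_backward is just an illustrative helper name, not part of this commit.

from apex import amp


def apex_backward(loss, optimizer):
    # scale_loss yields the loss multiplied by apex's current loss scale;
    # gradients are unscaled when the context exits
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    # release the graph once backward has run
    return loss.detach()

Functionally the two forms do the same work; the manual form in the plugin simply keeps the scaled loss in scope for the detach and return at the end.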
pytorch_lightning/plugins/native_amp.py (10 changes: 10 additions & 0 deletions)
@@ -23,6 +23,16 @@ def __init__(self, trainer):
     def connect(self, model, optimizers):
         return model, optimizers
 
+    def backward(self, closure_loss, optimizer, *args, **kwargs):
+        closure_loss = self.trainer.scaler.scale(closure_loss)
+
+        # do backward pass
+        closure_loss.backward(*args, **kwargs)
+
+        # once backward has been applied, release graph
+        closure_loss = closure_loss.detach()
+        return closure_loss
+
     def training_step(self, fx, args):
         with torch.cuda.amp.autocast():
             output = fx(*args)
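The NativeAMPPlugin.backward above covers only the scale-and-backward half of the torch.cuda.amp recipe; the matching scaler.step / scaler.update calls belong with the optimizer step rather than with backward. For orientation, a generic, self-contained version of the full recipe in plain PyTorch (not Lightning code; amp_training_step is an illustrative helper, and a CUDA device is assumed) looks like this:

import torch


def amp_training_step(model, optimizer, scaler, batch, target):
    # forward under autocast so eligible ops run in float16
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(batch), target)

    # what the plugin's backward covers: scale the loss, then backprop it
    scaler.scale(loss).backward()

    # what happens at optimizer-step time: unscale + step, then update the scale
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
    return loss.detach()


# usage sketch (only runs when a CUDA device is available)
if torch.cuda.is_available():
    model = torch.nn.Linear(4, 1).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()
    batch = torch.randn(8, 4, device="cuda")
    target = torch.randn(8, 1, device="cuda")
    amp_training_step(model, optimizer, scaler, batch, target)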
