Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enabled manual returns #4089

Merged
merged 1 commit into from
Oct 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pytorch_lightning/core/step_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,10 @@ def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] =
del out['meta']

for k, v in out.items():
# support manual opt where the user does not return a minimize key
if k == 'minimize' and v is None:
continue

if isinstance(v, dict):
in_d = result.get(k, {})
v = recursive_gather([v], in_d)
Expand Down
13 changes: 10 additions & 3 deletions pytorch_lightning/trainer/connectors/logger_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def training_epoch_end(self, model, epoch_output, num_optimizers):

epoch_output = self.__prepare_epoch_end_inputs(epoch_output)

if num_optimizers == 1:
if num_optimizers == 1 or not self.trainer.train_loop.automatic_optimization:
epoch_output = epoch_output[0]

# lightningmodule hook
Expand Down Expand Up @@ -447,11 +447,18 @@ def __auto_reduce_results_on_epoch_end(self, epoch_output):
for train_step_idx in range(len(opt_outputs)):
tbptt_outs = opt_outputs[train_step_idx]
tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(tbptt_outs)
time_reduced_outputs.append(tbptt_outs)
if len(tbptt_outs) > 1:
time_reduced_outputs.append(tbptt_outs)

if len(time_reduced_outputs) == 0:
continue

# reduce across training steps
opt_outputs = time_reduced_outputs[0].__class__.reduce_on_epoch_end(time_reduced_outputs)
opt_outputs.minimize = opt_outputs.minimize.mean()

# with manual opt need 1+ metrics because meta is always there
if opt_outputs.minimize is not None:
opt_outputs.minimize = opt_outputs.minimize.mean()
epoch_log_metrics.update(opt_outputs.epoch_log_metrics)
epoch_progress_bar_metrics.update(opt_outputs.epoch_pbar_metrics)

Expand Down
43 changes: 26 additions & 17 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,17 +313,22 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
if training_step_output_for_epoch_end is None:
return None

# accumulate loss
# (if accumulate_grad_batches = 1 no effect)
if is_result_obj:
closure_loss = training_step_output.minimize
else:
closure_loss = training_step_output.batch_loss
# enable empty loss when using manual opt
closure_loss = None
untouched_loss = None

if self.trainer.train_loop.automatic_optimization:
# accumulate loss
# (if accumulate_grad_batches = 1 no effect)
if is_result_obj:
closure_loss = training_step_output.minimize
else:
closure_loss = training_step_output.batch_loss

closure_loss = closure_loss / self.trainer.accumulate_grad_batches
closure_loss = closure_loss / self.trainer.accumulate_grad_batches

# the loss will get scaled for amp. avoid any modifications to it
untouched_loss = closure_loss.detach().clone()
# the loss will get scaled for amp. avoid any modifications to it
untouched_loss = closure_loss.detach().clone()

# result
result = AttributeDict(
Expand Down Expand Up @@ -681,13 +686,16 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
if self.trainer.terminate_on_nan:
self.trainer.detect_nan_tensors(opt_closure_result.loss)

# track total loss for logging (avoid mem leaks)
self.accumulated_loss.append(opt_closure_result.loss)

# track all the outputs across all steps
batch_opt_idx = opt_idx if len(batch_outputs) > 1 else 0
batch_outputs[batch_opt_idx].append(opt_closure_result.training_step_output_for_epoch_end)

if not self.automatic_optimization:
continue

# track total loss for logging (avoid mem leaks)
self.accumulated_loss.append(opt_closure_result.loss)

# ------------------------------
# BACKWARD PASS
# ------------------------------
Expand Down Expand Up @@ -748,12 +756,13 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
self.warning_cache.warn('training_step returned None if it was on purpose, ignore this warning...')
return None

# backward pass
with self.trainer.profiler.profile('model_backward'):
self.backward(result, optimizer, opt_idx)
if self.trainer.train_loop.automatic_optimization:
# backward pass
with self.trainer.profiler.profile('model_backward'):
self.backward(result, optimizer, opt_idx)

# hook
self.on_after_backward(result.training_step_output, batch_idx, result.loss)
# hook
self.on_after_backward(result.training_step_output, batch_idx, result.loss)

return result

Expand Down
135 changes: 135 additions & 0 deletions tests/trainer/optimization/test_manual_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,141 @@ def configure_optimizers(self):
assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * num_manual_backward_calls


def test_multiple_optimizers_manual_return(tmpdir):
os.environ['PL_DEV_DEBUG'] = '1'

"""
Tests that only training_step can be used
"""
class TestModel(BoringModel):
def training_step(self, batch, batch_idx, optimizer_idx):
# manual
(opt_a, opt_b) = self.optimizers()
loss_1 = self.step(batch[0])

# make sure there are no grads
if batch_idx > 0:
assert torch.all(self.layer.weight.grad == 0)

self.manual_backward(loss_1, opt_a)
opt_a.step()
opt_a.zero_grad()
assert torch.all(self.layer.weight.grad == 0)

# fake discriminator
loss_2 = self.step(batch[0])

# ensure we forward the correct params to the optimizer
# without retain_graph we can't do multiple backward passes
self.manual_backward(loss_2, opt_b, retain_graph=True)
self.manual_backward(loss_2, opt_a, retain_graph=True)

assert self.layer.weight.grad is not None
opt_b.step()
opt_b.zero_grad()
assert torch.all(self.layer.weight.grad == 0)

return {'something': 'else'}

def training_epoch_end(self, outputs) -> None:
# outputs should be an array with an entry per optimizer
assert len(outputs) == 2

def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
return optimizer, optimizer_2

model = TestModel()
model.val_dataloader = None

limit_train_batches = 2
trainer = Trainer(
automatic_optimization=False,
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
max_epochs=1,
log_every_n_steps=1,
weights_summary=None,
)

trainer.fit(model)

num_manual_backward_calls = 3
assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * num_manual_backward_calls


def test_multiple_optimizers_manual_return_and_log(tmpdir):
os.environ['PL_DEV_DEBUG'] = '1'

"""
Tests that only training_step can be used
"""
class TestModel(BoringModel):
def training_step(self, batch, batch_idx, optimizer_idx):
# manual
(opt_a, opt_b) = self.optimizers()
loss_1 = self.step(batch[0])

# make sure there are no grads
if batch_idx > 0:
assert torch.all(self.layer.weight.grad == 0)

self.manual_backward(loss_1, opt_a)
opt_a.step()
opt_a.zero_grad()
assert torch.all(self.layer.weight.grad == 0)

# fake discriminator
loss_2 = self.step(batch[0])

# ensure we forward the correct params to the optimizer
# without retain_graph we can't do multiple backward passes
self.manual_backward(loss_2, opt_b, retain_graph=True)
self.manual_backward(loss_2, opt_a, retain_graph=True)
self.log('a', loss_2, on_epoch=True)

assert self.layer.weight.grad is not None
opt_b.step()
opt_b.zero_grad()
assert torch.all(self.layer.weight.grad == 0)

return {'something': 'else'}

def training_epoch_end(self, outputs) -> None:
# outputs should be an array with an entry per optimizer
assert len(outputs) == 2

def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
return optimizer, optimizer_2

model = TestModel()
model.val_dataloader = None

limit_train_batches = 2
trainer = Trainer(
automatic_optimization=False,
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
max_epochs=1,
log_every_n_steps=1,
weights_summary=None,
)

trainer.fit(model)

num_manual_backward_calls = 3
assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * num_manual_backward_calls

expected = {'a', 'a_step', 'a_epoch', 'epoch'}
logged = set(trainer.logged_metrics.keys())
assert expected == logged


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_multiple_optimizers_manual_native_amp(tmpdir):
os.environ['PL_DEV_DEBUG'] = '1'
Expand Down