Skip to content

Commit

Permalink
Remove trainer.fit return value [2/n] (#7237)
Browse files Browse the repository at this point in the history
* `_fit_impl` refactor and types

* Fix return

* Remove return docstring

* Fixes

* Fixes

* Remove `trainer.fit` return value

* Update CHANGELOG

* flake8

* Undo results change

* Fix test

* Revert changes for a separate PR

* flake8
  • Loading branch information
carmocca authored Apr 28, 2021
1 parent bdc4272 commit 40f8023
Show file tree
Hide file tree
Showing 20 changed files with 35 additions and 72 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed legacy code to log or include metrics in the progress bar by returning them in a dict with the `"log"/"progress_bar"` magic keys. Use `self.log` instead ([#6734](https://github.com/PyTorchLightning/pytorch-lightning/pull/6734))


- Removed `trainer.fit()` return value of `1`. It has no return now ([#7237](https://github.com/PyTorchLightning/pytorch-lightning/pull/7237))


- Removed `optimizer_idx` argument from `training_step` in manual optimization ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093))


Expand Down
14 changes: 4 additions & 10 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ def _launch(
train_dataloader: Any = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: Optional[LightningDataModule] = None,
) -> Union[int, _EVALUATE_OUTPUT, _PREDICT_OUTPUT]:
) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
# set local properties on the model
self.model_connector.copy_trainer_model_properties(model)

Expand Down Expand Up @@ -497,9 +497,7 @@ def _launch(
self.state = TrainerState.FINISHED
self._running_stage = None

# return 1 when finished
# used for testing or when we need to know that training succeeded
return self.accelerator.results or 1
return self.accelerator.results

def pre_dispatch(self):
self.accelerator.pre_dispatch(self)
Expand Down Expand Up @@ -836,7 +834,7 @@ def fit(
train_dataloader: Any = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: Optional[LightningDataModule] = None,
) -> Optional[int]:
) -> None:
r"""
Runs the full optimization routine.
Expand All @@ -857,15 +855,11 @@ def fit(
self.state = TrainerState.FITTING
self.training = True

results = self._launch(
model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule
)
self._launch(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule)

assert self.state.stopped
self.training = False

return results

def validate(
self,
model: Optional[LightningModule] = None,
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/utilities/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def before_fit(self) -> None:

def fit(self) -> None:
"""Runs fit of the instantiated trainer class and prepared fit keyword arguments"""
self.fit_result = self.trainer.fit(**self.fit_kwargs)
self.trainer.fit(**self.fit_kwargs)

def after_fit(self) -> None:
"""Implement to run some code after fit has finished"""
3 changes: 1 addition & 2 deletions tests/accelerators/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ def test_evaluate(tmpdir, trainer_kwargs):
**trainer_kwargs
)

result = trainer.fit(model, datamodule=dm)
assert result
trainer.fit(model, datamodule=dm)
assert 'ckpt' in trainer.checkpoint_callback.best_model_path

old_weights = model.layer_0.weight.clone().detach().cpu()
Expand Down
9 changes: 3 additions & 6 deletions tests/accelerators/test_tpu_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,7 @@ def test_weight_tying_warning(tmpdir, capsys=None):
trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)

with pytest.warns(UserWarning, match=r'The model layers do not match after moving to the target device.'):
result = trainer.fit(model)
assert result
trainer.fit(model)


# @RunIf(tpu=True)
Expand All @@ -106,8 +105,7 @@ def test_weight_tying_warning(tmpdir, capsys=None):
# Ensure no warning for parameter mismatch is thrown.
# """

# # TODO (kaushikb11): Add `paramter_validation` specific to
# # TPU Accelerators
# # TODO (kaushikb11): Add `parameter_validation` specific to TPU Accelerators
# class Model(WeightSharingModule):

# def on_post_move_to_device(self):
Expand All @@ -117,8 +115,7 @@ def test_weight_tying_warning(tmpdir, capsys=None):
# trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)

# with pytest.warns(UserWarning) as warnings:
# result = trainer.fit(model)
# assert result
# trainer.fit(model)

# assert not list(filter(lambda x: 'The model layers do not match' in str(x), warnings.list))
# assert len(trainer.test(model)) == 1
3 changes: 1 addition & 2 deletions tests/callbacks/test_callback_hook_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ def on_train_epoch_end(self, outputs) -> None:

assert any(isinstance(c, CB) for c in trainer.callbacks)

results = trainer.fit(model)
assert results
trainer.fit(model)


def test_on_val_epoch_end_outputs(tmpdir):
Expand Down
9 changes: 3 additions & 6 deletions tests/callbacks/test_prediction_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,18 @@ def write_on_epoch_end(self, *args, **kwargs):

cb = CustomPredictionWriter("batch_and_epoch")
trainer = Trainer(limit_predict_batches=4, callbacks=cb)
results = trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
assert cb.write_on_batch_end_called
assert cb.write_on_epoch_end_called
assert results == 1

cb = CustomPredictionWriter("batch")
trainer = Trainer(limit_predict_batches=4, callbacks=cb)
results = trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
assert cb.write_on_batch_end_called
assert not cb.write_on_epoch_end_called
assert results == 1

cb = CustomPredictionWriter("epoch")
trainer = Trainer(limit_predict_batches=4, callbacks=cb)
results = trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
assert not cb.write_on_batch_end_called
assert cb.write_on_epoch_end_called
assert results == 1
6 changes: 2 additions & 4 deletions tests/checkpointing/test_legacy_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,11 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: str):

model = DummyModel.load_from_checkpoint(path_ckpt)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=6)
result = trainer.fit(model)
assert result
trainer.fit(model)

# todo
# model = DummyModel()
# trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, resume_from_checkpoint=path_ckpt)
# result = trainer.fit(model)
# assert result
# trainer.fit(model)

sys.path = orig_sys_paths
6 changes: 2 additions & 4 deletions tests/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,7 @@ def configure_optimizers(self):
max_epochs=max_epochs,
progress_bar_refresh_rate=0,
)
results = trainer.fit(model)
assert results
trainer.fit(model)
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
Expand Down Expand Up @@ -232,8 +231,7 @@ def configure_optimizers(self):
progress_bar_refresh_rate=0,
num_sanity_val_steps=0,
)
results = trainer.fit(model)
assert results
trainer.fit(model)
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
Expand Down
9 changes: 3 additions & 6 deletions tests/core/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,8 @@ def test_train_loop_only(tmpdir):
)

# fit model
result = trainer.fit(model, datamodule=dm)
trainer.fit(model, datamodule=dm)
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
assert result
assert trainer.callback_metrics['train_loss'] < 1.0


Expand All @@ -294,9 +293,8 @@ def test_train_val_loop_only(tmpdir):
)

# fit model
result = trainer.fit(model, datamodule=dm)
trainer.fit(model, datamodule=dm)
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
assert result
assert trainer.callback_metrics['train_loss'] < 1.0


Expand Down Expand Up @@ -353,10 +351,9 @@ def test_full_loop(tmpdir):
)

# fit model
result = trainer.fit(model, dm)
trainer.fit(model, dm)
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
assert dm.trainer is not None
assert result

# validate
result = trainer.validate(datamodule=dm)
Expand Down
4 changes: 1 addition & 3 deletions tests/core/test_lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,7 @@ def optimizer_step(
accumulate_grad_batches=1,
limit_val_batches=0,
)

results = trainer.fit(model)
assert results
trainer.fit(model)


def test_toggle_untoggle_3_optimizers_shared_parameters(tmpdir):
Expand Down
9 changes: 3 additions & 6 deletions tests/core/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,11 @@ def test_dataloader(self):
assert not prediction_file.exists()

if do_train:
result = trainer.fit(model, dm)
trainer.fit(model, dm)
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
assert result
result = trainer.test(datamodule=dm)
# TODO: add end-to-end test
# assert result[0]['test_loss'] < 0.6
trainer.test(datamodule=dm)
else:
result = trainer.test(model, datamodule=dm)
trainer.test(model, datamodule=dm)

# check prediction file now exists and is of expected length
assert prediction_file.exists()
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
callbacks=[checkpoint],
logger=logger,
)
_ = trainer.fit(model)
trainer.fit(model)

# correct result and ok accuracy
assert trainer.state == TrainerState.FINISHED, 'amp + ddp model failed to complete'
Expand Down
3 changes: 1 addition & 2 deletions tests/models/test_horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,8 +409,7 @@ def configure_optimizers(self):
limit_train_batches=0.2,
accelerator='horovod'
)
results = trainer.fit(model)
assert results == 1
trainer.fit(model)

adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups][0]
adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups][0]
Expand Down
4 changes: 1 addition & 3 deletions tests/trainer/flags/test_min_max_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@ def test_min_max_steps_epochs(tmpdir, min_epochs, max_epochs, min_steps, max_ste
max_steps=max_steps,
weights_summary=None,
)

result = trainer.fit(model)
assert result == 1, "Training did not complete"
trainer.fit(model)

# check training stopped at max_epochs or max_steps
if trainer.max_steps and not trainer.max_epochs:
Expand Down
2 changes: 1 addition & 1 deletion tests/trainer/optimization/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def test_lr_scheduler_strict(tmpdir):
with pytest.warns(
RuntimeWarning, match=r'ReduceLROnPlateau conditioned on metric .* which is not available but strict'
):
assert trainer.fit(model)
trainer.fit(model)


def test_unknown_configure_optimizers_raises(tmpdir):
Expand Down
3 changes: 1 addition & 2 deletions tests/trainer/test_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,8 +983,7 @@ def test_fit_multiple_train_loaders(tmpdir, multiple_trainloader_mode, num_train
default_root_dir=tmpdir,
multiple_trainloader_mode=multiple_trainloader_mode,
)

assert 1 == trainer.fit(model)
trainer.fit(model)
# verify the num_training_batches according to the multiple_trainloader_mode
assert num_training_batches == trainer.num_training_batches

Expand Down
10 changes: 2 additions & 8 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,10 +759,9 @@ def validation_epoch_end(self, *args, **kwargs):
)

trainer = Trainer(**trainer_options)
result = trainer.fit(model)
trainer.fit(model)

# check that limit_val_batches=0 turns off validation
assert result == 1, "training failed to complete"
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
assert trainer.current_epoch == 1
assert not model.validation_step_invoked, "`validation_step` should not run when `limit_val_batches=0`"
Expand Down Expand Up @@ -1593,8 +1592,6 @@ def predict(
assert len(results) == 2
assert len(results[0]) == num_samples
assert results[0][0].shape == torch.Size([1, 2])
else:
assert results == 1


def test_trainer_predict_no_return(tmpdir):
Expand Down Expand Up @@ -1693,8 +1690,6 @@ def test_predict_return_predictions_cpu(return_predictions, precision, tmpdir):
assert len(preds) == 1
assert preds[0].shape == torch.Size([1, 2])
assert preds[0].dtype == (torch.float64 if precision == 64 else torch.float32)
else:
assert preds == 1


@pytest.mark.parametrize(
Expand Down Expand Up @@ -1735,7 +1730,7 @@ def training_epoch_end(self, *args, **kwargs):
max_epochs=5,
limit_train_batches=limit_train_batches,
)
result = trainer.fit(model, train_loader)
trainer.fit(model, train_loader)

params_string = f"""`limit_train_batches={limit_train_batches}`, `dataset_len={dataset_len}`
& `batch_size={batch_size}` as
Expand All @@ -1745,7 +1740,6 @@ def training_epoch_end(self, *args, **kwargs):
else:
error_string = f"should not run with {params_string}"

assert result == 1, "training failed to complete"
assert trainer.state == TrainerState.FINISHED
assert trainer.global_step == global_step
assert trainer.num_training_batches == num_training_batches
Expand Down
4 changes: 1 addition & 3 deletions tests/trainer/test_training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,7 @@ def on_train_epoch_end(self, outputs):
progress_bar_refresh_rate=0,
weights_summary=None,
)

result = trainer.fit(model)
assert result == 1, "Training did not complete"
trainer.fit(model)


def test_training_starts_with_seed(tmpdir):
Expand Down
2 changes: 0 additions & 2 deletions tests/utilities/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,6 @@ def test_lightning_cli_args(tmpdir):
with mock.patch('sys.argv', ['any.py'] + cli_args):
cli = LightningCLI(BoringModel, BoringDataModule, trainer_defaults={'callbacks': [LearningRateMonitor()]})

assert cli.fit_result == 1
assert cli.config['seed_everything'] == 1234
config_path = tmpdir / 'lightning_logs' / 'version_0' / 'config.yaml'
assert os.path.isfile(config_path)
Expand Down Expand Up @@ -324,7 +323,6 @@ def test_lightning_cli_config_and_subclass_mode(tmpdir):
trainer_defaults={'callbacks': LearningRateMonitor()}
)

assert cli.fit_result == 1
config_path = tmpdir / 'lightning_logs' / 'version_0' / 'config.yaml'
assert os.path.isfile(config_path)
with open(config_path) as f:
Expand Down

0 comments on commit 40f8023

Please sign in to comment.