Add trainer.predict(ckpt_path) #7430

Merged: 7 commits, May 12, 2021
Changes from 6 commits
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added


- Added argument `trainer.predict(ckpt_path)` ([#7430](https://github.com/PyTorchLightning/pytorch-lightning/pull/7430))


- Added `clip_grad_by_value` support for TPUs ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))


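To make the new argument concrete, here is a minimal usage sketch. `MyModel`, the monitored metric, and the dataloader setup are illustrative assumptions, not part of this PR:

```python
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

model = MyModel()  # hypothetical LightningModule
trainer = Trainer(max_epochs=2, callbacks=[ModelCheckpoint(monitor="val_loss")])
trainer.fit(model)

predict_loader = DataLoader(predict_dataset, batch_size=32)  # hypothetical dataset

# 'best' (the default): load the best checkpoint tracked by ModelCheckpoint.
predictions = trainer.predict(dataloaders=predict_loader, ckpt_path="best")

# A specific checkpoint file on disk.
predictions = trainer.predict(dataloaders=predict_loader, ckpt_path="some/checkpoint.ckpt")

# None: predict with the model's current in-memory weights.
predictions = trainer.predict(dataloaders=predict_loader, ckpt_path=None)
```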
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/predict_loop.py
@@ -60,6 +60,7 @@ def should_store_predictions(self) -> bool:

    def on_trainer_init(self):
        self.trainer.num_predict_batches = []
        self.trainer.predicted_ckpt_path = None

    def get_predict_dataloaders(self):
        self.trainer.reset_predict_dataloader(self.trainer.lightning_module)
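The attribute initialized above mirrors the existing `validated_ckpt_path` and `tested_ckpt_path`: after a predict run it records which checkpoint was actually loaded. A short sketch, continuing the hypothetical setup above and mirroring the assertions in the updated test further down:

```python
# Predicting from the best checkpoint records its path on the trainer.
trainer.predict(dataloaders=predict_loader, ckpt_path="best")
assert trainer.predicted_ckpt_path == trainer.checkpoint_callback.best_model_path

# With ckpt_path=None, no checkpoint is loaded and the attribute is None.
trainer.predict(dataloaders=predict_loader, ckpt_path=None)
assert trainer.predicted_ckpt_path is None
```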
14 changes: 12 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -383,6 +383,7 @@ def __init__(
            num_sanity_val_steps,
        )
        self.evaluation_loop.on_trainer_init()
        self.predict_loop.on_trainer_init()

        # configure tuner
        self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)
@@ -589,6 +590,7 @@ def predict(
        dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
        return_predictions: Optional[bool] = None,
        ckpt_path: Optional[str] = 'best',
    ) -> Optional[_PREDICT_OUTPUT]:
        r"""

@@ -605,6 +607,10 @@
            return_predictions: Whether to return predictions.
                ``True`` by default except when an accelerator that spawns processes is used (not supported).

            ckpt_path: Either ``best`` or path to the checkpoint you wish to use to predict.
                If ``None``, use the current weights of the model.
                When the model is given as an argument, this parameter will not apply.

        Returns:
            Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.
        """
@@ -614,8 +620,6 @@
        # --------------------
        Trainer._log_api_event("predict")

        model = model or self.lightning_module

        self.state.fn = TrainerFn.PREDICTING
        self.state.status = TrainerStatus.RUNNING
        self.predicting = True
@@ -625,9 +629,15 @@
        if dataloaders is not None and datamodule:
            raise MisconfigurationException('You cannot pass both `trainer.predict(dataloaders=..., datamodule=...)`')

        model_provided = model is not None
        model = model or self.lightning_module

        # links data to the trainer
        self.data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)

        if not model_provided:
Contributor: use model directly. No need for `model_provided`.

Contributor Author: Why not? We also use it in `validate` and `test`.

Member: @tchaton it cannot, as you do the mapping after it... I guess

            self.predicted_ckpt_path = self.__load_ckpt_weights(ckpt_path)

        results = self._run(model)

        assert self.state.stopped
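The review thread above concerns the `model_provided` flag: a checkpoint is only restored when the caller did not pass a model explicitly. Below is a simplified, hypothetical sketch of that dispatch, not the actual `Trainer` internals; `resolve_predict_weights` and `_load_weights` are names invented for illustration:

```python
from typing import Optional, Tuple

import torch


def _load_weights(model, path: str) -> None:
    # Placeholder: the real Trainer restores weights via its checkpoint connector.
    model.load_state_dict(torch.load(path)["state_dict"])


def resolve_predict_weights(passed_model, attached_module,
                            ckpt_path: Optional[str] = "best") -> Tuple[object, Optional[str]]:
    """If a model is passed, keep its in-memory weights; otherwise restore the
    module already attached to the trainer from `ckpt_path` (unless it is None)."""
    model_provided = passed_model is not None
    model = passed_model or attached_module
    predicted_ckpt_path = None
    if not model_provided and ckpt_path is not None:
        # The real code resolves 'best' through the ModelCheckpoint callback.
        predicted_ckpt_path = ckpt_path
        _load_weights(model, predicted_ckpt_path)
    return model, predicted_ckpt_path
```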
41 changes: 24 additions & 17 deletions tests/trainer/test_trainer.py
@@ -611,7 +611,7 @@ def test_benchmark_option(tmpdir):

@pytest.mark.parametrize("ckpt_path", (None, "best", "specific"))
@pytest.mark.parametrize("save_top_k", (-1, 0, 1, 2))
@pytest.mark.parametrize("fn", ("validate", "test"))
@pytest.mark.parametrize("fn", ("validate", "test", "predict"))
def test_tested_checkpoint_path(tmpdir, ckpt_path, save_top_k, fn):

    class TestModel(BoringModel):
@@ -620,48 +620,55 @@ def validation_step(self, batch, batch_idx):
self.log("foo", -batch_idx)
return super().validation_step(batch, batch_idx)

def test_step(self, *args):
return self.validation_step(*args)

def predict_step(self, *args):
args = args[:-1] # remove `dataloader_idx`
return self.validation_step(*args)

model = TestModel()
model.test_epoch_end = None
trainer = Trainer(
max_epochs=2,
limit_val_batches=1,
limit_test_batches=1,
limit_predict_batches=1,
progress_bar_refresh_rate=0,
default_root_dir=tmpdir,
callbacks=[ModelCheckpoint(monitor="foo", save_top_k=save_top_k)],
)
trainer.fit(model)

test_or_validate = getattr(trainer, fn)
trainer_fn = getattr(trainer, fn)
path_attr = f"{fn}{'d' if fn == 'validate' else 'ed'}_ckpt_path"
assert getattr(trainer, path_attr) is None
Comment on lines +643 to +645
Contributor: this feels fragile. can we be more explicit here?

Contributor Author: The test length would grow considerably (might as well split it into 3 tests then) if we don't use getattr. I know dynamic is ugly but I'd rather have a concise test.


if ckpt_path == "best":
# ckpt_path is 'best', meaning we load the best weights
if save_top_k == 0:
with pytest.raises(MisconfigurationException, match=".*is not configured to save the best.*"):
test_or_validate(ckpt_path=ckpt_path)
trainer_fn(ckpt_path=ckpt_path)
else:
test_or_validate(ckpt_path=ckpt_path)
if fn == "test":
assert trainer.tested_ckpt_path == trainer.checkpoint_callback.best_model_path
else:
assert trainer.validated_ckpt_path == trainer.checkpoint_callback.best_model_path
trainer_fn(ckpt_path=ckpt_path)
assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
elif ckpt_path is None:
# ckpt_path is None, meaning we don't load any checkpoints and
# use the weights from the end of training
test_or_validate(ckpt_path=ckpt_path)
assert trainer.tested_ckpt_path is None
assert trainer.validated_ckpt_path is None
trainer_fn(ckpt_path=ckpt_path)
assert getattr(trainer, path_attr) is None
else:
# specific checkpoint, pick one from saved ones
if save_top_k == 0:
with pytest.raises(FileNotFoundError):
test_or_validate(ckpt_path="random.ckpt")
trainer_fn(ckpt_path="random.ckpt")
else:
ckpt_path = str(
list((Path(tmpdir) / f"lightning_logs/version_{trainer.logger.version}/checkpoints").iterdir()
)[0].absolute()
)
test_or_validate(ckpt_path=ckpt_path)
if fn == "test":
assert trainer.tested_ckpt_path == ckpt_path
else:
assert trainer.validated_ckpt_path == ckpt_path
trainer_fn(ckpt_path=ckpt_path)
assert getattr(trainer, path_attr) == ckpt_path
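To make the dynamic lookup debated above concrete, this is what the `path_attr` f-string expands to for each parametrized `fn` (a standalone illustration, not part of the test):

```python
for fn in ("validate", "test", "predict"):
    path_attr = f"{fn}{'d' if fn == 'validate' else 'ed'}_ckpt_path"
    print(f"{fn} -> {path_attr}")

# validate -> validated_ckpt_path
# test -> tested_ckpt_path
# predict -> predicted_ckpt_path
```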


def test_disabled_training(tmpdir):