Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Trainer.validate(…) method to run one validation epoch #4948

Merged
merged 45 commits into from
Mar 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
edb3e83
Refactor Trainer in advance of implementing Trainer.validate
EliaCereda Dec 2, 2020
03d7994
Merge remote-tracking branch 'upstream/master' into feature/trainer-v…
EliaCereda Dec 2, 2020
5a54485
Add Trainer.validate(...) method to perform one evaluation epoch over…
EliaCereda Dec 2, 2020
e06775c
Rename methods in Trainer and Accelerator to reflect that they are us…
EliaCereda Dec 2, 2020
b4e409c
Update docs to mention the new Trainer.validate method and associated…
EliaCereda Dec 2, 2020
96e42ba
Add tests for Trainer.validate(…)
EliaCereda Dec 2, 2020
85b3c9f
Update CHANGELOG.md
EliaCereda Dec 2, 2020
39113dc
Merge branch 'master' into feature/trainer-validate-2
tchaton Dec 3, 2020
a6be0d8
Replace usages of Trainer.testing with Trainer.evaluating, should be …
EliaCereda Dec 4, 2020
a922d57
Merge remote-tracking branch 'upstream/master' into feature/trainer-v…
EliaCereda Dec 4, 2020
595f4e8
Clean up calls to LightningDataModule.setup()
EliaCereda Dec 8, 2020
0b09248
Update test_trainer_validate_loop.py to use BoringModel instead of Ev…
EliaCereda Dec 8, 2020
d691d79
Merge remote-tracking branch 'upstream/master' into feature/trainer-v…
EliaCereda Dec 8, 2020
52eaa70
Merge branch 'feature/trainer-validate-1' into feature/trainer-valida…
EliaCereda Dec 8, 2020
06b4419
Fix ShardedPlugin when evaluating
EliaCereda Dec 8, 2020
e6a8be9
Merge remote-tracking branch 'origin/feature/trainer-validate-1' into…
EliaCereda Dec 8, 2020
389940e
Add tests for Trainer.validate with ShardedPlugin
EliaCereda Dec 8, 2020
6d0a95a
Remove superfluous calls to LoggerConnector.set_stage in validate() a…
EliaCereda Dec 10, 2020
704b121
Update more docstrings to mention Trainer.validate
EliaCereda Dec 10, 2020
f6e0759
Merge branch 'release/1.2-dev' into feature/trainer-validate-1
tchaton Jan 11, 2021
90e59c7
Merge remote-tracking branch 'upstream/release/1.2-dev' into feature/…
EliaCereda Jan 26, 2021
45d7e0a
Merge branch 'feature/trainer-validate-1' into feature/trainer-valida…
EliaCereda Jan 26, 2021
12a85b3
Pass {fit,validate,test,predict} to setup()
carmocca Mar 7, 2021
d49ccd1
Fix doctest
carmocca Mar 7, 2021
23db135
stage: Optional[str] = None
carmocca Mar 7, 2021
84f5fdb
Trailing whitespace
carmocca Mar 7, 2021
188b9fe
Update docs and CHANGELOG
carmocca Mar 7, 2021
37473f0
Mention teardown
carmocca Mar 7, 2021
0a30abf
Self-review
carmocca Mar 7, 2021
0e9d69c
Address Borda's comments
carmocca Mar 7, 2021
04343ce
Merge branch 'deleteme-carmocca' into feature/trainer-validate-2
carmocca Mar 7, 2021
9758c7b
Fixing conflicts
carmocca Mar 7, 2021
18280df
Implement Trainer.validate
carmocca Mar 7, 2021
e582d58
Refactor
carmocca Mar 7, 2021
1a5b620
Merge branch 'master' into feature/trainer-validate-2
carmocca Mar 8, 2021
5b99ec0
flake8
carmocca Mar 8, 2021
9f4dce2
Refactor
carmocca Mar 8, 2021
088d4bc
Missing import
carmocca Mar 8, 2021
58fcca4
Fix test
carmocca Mar 8, 2021
babb73d
Same threshold
carmocca Mar 8, 2021
235dc27
Address tchaton's comments
carmocca Mar 8, 2021
73dd265
Merge branch 'master' into feature/trainer-validate-2
carmocca Mar 8, 2021
e423b98
Missing import
carmocca Mar 10, 2021
cdec83b
Merge branch 'master' into feature/trainer-validate-2
carmocca Mar 10, 2021
8fab50f
Apply suggestions from code review
carmocca Mar 10, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))


- Added `Trainer.validate()` method to perform one evaluation epoch over the validation set ([#4948](https://github.com/PyTorchLightning/pytorch-lightning/pull/4948))


- Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915))


Expand Down
15 changes: 14 additions & 1 deletion docs/source/common/trainer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,27 @@ So you can run it like so:

------------

Validation
----------
You can perform an evaluation epoch over the validation set, outside of the training loop,
using :meth:`pytorch_lightning.trainer.trainer.Trainer.validate`. This might be
useful if you want to collect new metrics from a model right at its initialization
or after it has already been trained.

.. code-block:: python

trainer.validate(val_dataloaders=val_dataloaders)

------------

Testing
-------
Once you're done training, feel free to run the test set!
(Only right before publishing your paper or pushing to production)

.. code-block:: python

trainer.test(test_dataloaders=test_dataloader)
trainer.test(test_dataloaders=test_dataloaders)

------------

Expand Down
11 changes: 8 additions & 3 deletions pytorch_lightning/callbacks/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,9 +355,11 @@ def init_predict_tqdm(self) -> tqdm:

def init_validation_tqdm(self) -> tqdm:
""" Override this to customize the tqdm bar for validation. """
# The main progress bar doesn't exist in `trainer.validate()`
has_main_bar = self.main_progress_bar is not None
bar = tqdm(
desc='Validating',
position=(2 * self.process_position + 1),
position=(2 * self.process_position + has_main_bar),
disable=self.is_disabled,
leave=False,
dynamic_ncols=True,
Expand Down Expand Up @@ -426,7 +428,8 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,

def on_validation_end(self, trainer, pl_module):
super().on_validation_end(trainer, pl_module)
self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
if self.main_progress_bar is not None:
self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
self.val_progress_bar.close()

def on_train_end(self, trainer, pl_module):
Expand Down Expand Up @@ -479,8 +482,10 @@ def print(
def _should_update(self, current, total):
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

def _update_bar(self, bar):
def _update_bar(self, bar: Optional[tqdm]) -> None:
""" Updates the bar by the refresh rate without overshooting. """
if bar is None:
return
if bar.total is not None:
delta = min(self.refresh_rate, bar.total - bar.n)
else:
Expand Down
21 changes: 13 additions & 8 deletions pytorch_lightning/trainer/configuration_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
Expand All @@ -22,18 +23,24 @@ class ConfigValidator(object):
def __init__(self, trainer):
self.trainer = trainer

def verify_loop_configurations(self, model: LightningModule):
def verify_loop_configurations(self, model: LightningModule) -> None:
r"""
Checks that the model is configured correctly before the run is started.

Args:
model: The model to check the configuration.

"""
if self.trainer.training:
if self.trainer.state == TrainerState.FITTING:
self.__verify_train_loop_configuration(model)
elif self.trainer.evaluating:
self.__verify_eval_loop_configuration(model)
self.__verify_eval_loop_configuration(model, 'val')
elif self.trainer.state == TrainerState.TUNING:
self.__verify_train_loop_configuration(model)
carmocca marked this conversation as resolved.
Show resolved Hide resolved
elif self.trainer.state == TrainerState.VALIDATING:
self.__verify_eval_loop_configuration(model, 'val')
elif self.trainer.state == TrainerState.TESTING:
self.__verify_eval_loop_configuration(model, 'test')
# TODO: add predict

def __verify_train_loop_configuration(self, model):
# -----------------------------------
Expand Down Expand Up @@ -81,11 +88,9 @@ def __verify_train_loop_configuration(self, model):
' It ensures optimizer_step or optimizer_zero_grad are called on every batch.'
)

def __verify_eval_loop_configuration(self, model):
stage = "val" if self.trainer.validating else "test"

def __verify_eval_loop_configuration(self, model: LightningModule, stage: str) -> None:
loader_name = f'{stage}_dataloader'
step_name = f'{stage}_step'
step_name = 'validation_step' if stage == 'val' else 'test_step'

has_loader = is_overridden(loader_name, model)
has_step = is_overridden(step_name, model)
Expand Down
10 changes: 5 additions & 5 deletions pytorch_lightning/trainer/connectors/data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,10 @@ def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloa
def attach_dataloaders(
self,
model,
train_dataloader=None,
val_dataloaders=None,
test_dataloaders=None,
predict_dataloaders=None,
train_dataloader: Optional[Union[DataLoader, List[DataLoader]]] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
):
# when dataloader is passed via fit, patch the train_dataloader
# functions to overwrite with these implementations
Expand All @@ -112,7 +112,7 @@ def attach_dataloaders(
if predict_dataloaders is not None:
model.predict_dataloader = _PatchDataLoader(predict_dataloaders)

def attach_datamodule(self, model, datamodule: Optional[LightningDataModule]) -> None:
def attach_datamodule(self, model, datamodule: Optional[LightningDataModule] = None) -> None:
# We use datamodule if it's been provided, otherwise we check model for it
datamodule = datamodule or getattr(model, 'datamodule', None)

Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class RunningStage(LightningEnum):
"""
TRAINING = 'train'
SANITY_CHECKING = 'sanity_check'
VALIDATING = 'validation'
VALIDATING = 'validate'
carmocca marked this conversation as resolved.
Show resolved Hide resolved
TESTING = 'test'
PREDICTING = 'predict'
TUNING = 'tune'
Expand Down
140 changes: 90 additions & 50 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,69 @@ def run_sanity_check(self, ref_model):

self._running_stage = stage

def validate(
self,
model: Optional[LightningModule] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
carmocca marked this conversation as resolved.
Show resolved Hide resolved
ckpt_path: Optional[str] = 'best',
verbose: bool = True,
datamodule: Optional[LightningDataModule] = None,
):
r"""
Perform one evaluation epoch over the validation set.

Args:
model: The model to validate.

val_dataloaders: Either a single PyTorch DataLoader or a list of them,
specifying validation samples.

ckpt_path: Either ``best`` or path to the checkpoint you wish to validate.
If ``None``, use the current weights of the model.
carmocca marked this conversation as resolved.
Show resolved Hide resolved
When the model is given as argument, this parameter will not apply.

verbose: If True, prints the validation results.

datamodule: A instance of :class:`LightningDataModule`.

Returns:
The dictionary with final validation results returned by validation_epoch_end.
If validation_epoch_end is not defined, the output is a list of the dictionaries
returned by validation_step.
"""
# --------------------
# SETUP HOOK
# --------------------
self.verbose_evaluate = verbose

self.state = TrainerState.VALIDATING
self.validating = True

# If you supply a datamodule you can't supply val_dataloaders
if val_dataloaders and datamodule:
raise MisconfigurationException(
'You cannot pass both `trainer.validate(val_dataloaders=..., datamodule=...)`'
)

model_provided = model is not None
model = model or self.lightning_module

# Attach datamodule to get setup/prepare_data added to model before the call to it below
self.data_connector.attach_datamodule(model, datamodule)
# Attach dataloaders (if given)
self.data_connector.attach_dataloaders(model, val_dataloaders=val_dataloaders)

if not model_provided:
self.validated_ckpt_path = self.__load_ckpt_weights(model, ckpt_path=ckpt_path)
carmocca marked this conversation as resolved.
Show resolved Hide resolved

# run validate
results = self.fit(model)

assert self.state.stopped
self.validating = False

return results

def test(
self,
model: Optional[LightningModule] = None,
Expand All @@ -833,17 +896,19 @@ def test(
fit to make sure you never run on your test set until you want to.

Args:
ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
If ``None``, use the current weights of the model. Default to ``best``.
datamodule: A instance of :class:`LightningDataModule`.

model: The model to test.

test_dataloaders: Either a single PyTorch DataLoader or a list of them,
specifying test samples.

ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
If ``None``, use the current weights of the model.
carmocca marked this conversation as resolved.
Show resolved Hide resolved
When the model is given as argument, this parameter will not apply.

verbose: If True, prints the test results.

datamodule: A instance of :class:`LightningDataModule`.

Returns:
Returns a list of dictionaries, one for each test dataloader containing their respective metrics.
"""
Expand All @@ -858,30 +923,33 @@ def test(
# If you supply a datamodule you can't supply test_dataloaders
if test_dataloaders and datamodule:
raise MisconfigurationException(
'You cannot pass test_dataloaders to trainer.test if you supply a datamodule'
'You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`'
)

model_provided = model is not None
model = model or self.lightning_module

# Attach datamodule to get setup/prepare_data added to model before the call to it below
self.data_connector.attach_datamodule(model, datamodule)
results = (
self.__evaluate_given_model(model, dataloaders=test_dataloaders) if model_provided else
self.__evaluate_using_weights(model, ckpt_path=ckpt_path, dataloaders=test_dataloaders)
)
# Attach dataloaders (if given)
self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)

if not model_provided:
carmocca marked this conversation as resolved.
Show resolved Hide resolved
self.tested_ckpt_path = self.__load_ckpt_weights(model, ckpt_path=ckpt_path)

# run test
results = self.fit(model)

assert self.state.stopped
self.testing = False

return results

def __evaluate_using_weights(
def __load_ckpt_weights(
self,
model,
ckpt_path: Optional[str] = None,
dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None
):
) -> Optional[str]:
# if user requests the best checkpoint but we don't have it, error
if ckpt_path == 'best' and not self.checkpoint_callback.best_model_path:
raise MisconfigurationException(
Expand All @@ -894,42 +962,18 @@ def __evaluate_using_weights(
if ckpt_path == 'best':
ckpt_path = self.checkpoint_callback.best_model_path

if len(ckpt_path) == 0:
rank_zero_warn(
f'`.test()` found no path for the best weights, {ckpt_path}. Please'
' specify a path for a checkpoint `.test(ckpt_path=PATH)`'
if not ckpt_path:
fn = self.state.value
raise MisconfigurationException(
f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please'
' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`'
)
return {}

self.training_type_plugin.barrier()

ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt['state_dict'])

# attach dataloaders
if dataloaders is not None:
self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders)

if self.validating:
self.validated_ckpt_path = ckpt_path
else:
self.tested_ckpt_path = ckpt_path

# run test
results = self.fit(model)

return results

def __evaluate_given_model(self, model, dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None):
# attach data
if dataloaders is not None:
self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders)

# run test
# sets up testing so we short circuit to eval
results = self.fit(model)

return results
return ckpt_path

def predict(
self,
Expand Down Expand Up @@ -970,15 +1014,11 @@ def predict(
'You cannot pass dataloaders to trainer.predict if you supply a datamodule.'
)

if datamodule is not None:
# Attach datamodule to get setup/prepare_data added to model before the call to it below
self.data_connector.attach_datamodule(model, datamodule)

# attach data
if dataloaders is not None:
self.data_connector.attach_dataloaders(model, predict_dataloaders=dataloaders)
# Attach datamodule to get setup/prepare_data added to model before the call to it below
self.data_connector.attach_datamodule(model, datamodule)
# Attach dataloaders (if given)
self.data_connector.attach_dataloaders(model, predict_dataloaders=dataloaders)

self.model = model
carmocca marked this conversation as resolved.
Show resolved Hide resolved
results = self.fit(model)

assert self.state.stopped
Expand Down
Loading