Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Logging in train callbacks #4258

Closed
wants to merge 42 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
ac40d1b
release only training callback
tchaton Oct 20, 2020
8231ae3
update for flake8
tchaton Oct 20, 2020
d6eb8d8
resolve `test_multiple_optimizers_manual_return_and_log`
tchaton Oct 20, 2020
7ae6f79
resolve `test_multiple_optimizers_manual_return`
tchaton Oct 20, 2020
9434d11
release only training callback
tchaton Oct 20, 2020
4741258
update for flake8
tchaton Oct 20, 2020
77fae0e
resolve `test_multiple_optimizers_manual_return_and_log`
tchaton Oct 20, 2020
b47b390
resolve `test_multiple_optimizers_manual_return`
tchaton Oct 20, 2020
2a3c72d
Update pytorch_lightning/core/step_result.py
tchaton Oct 21, 2020
542d8d3
Update pytorch_lightning/core/step_result.py
tchaton Oct 21, 2020
290a160
Update pytorch_lightning/core/step_result.py
tchaton Oct 21, 2020
0dfe8c9
Update pytorch_lightning/core/step_result.py
tchaton Oct 21, 2020
104c9f5
remove mixin
tchaton Oct 21, 2020
abe57d3
Merge branch 'FEATURE/logging_in_train_callbacks' of https://github.c…
tchaton Oct 21, 2020
f4e2477
remove explicit mixin
tchaton Oct 21, 2020
f59d10d
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 22, 2020
bddd61d
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 22, 2020
0740f1e
resolve logging bug
tchaton Oct 22, 2020
82fc4fe
repair bug
tchaton Oct 22, 2020
075d5bf
resolve pep8
tchaton Oct 22, 2020
f71e588
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 22, 2020
36dbc96
resolve formatting bug
tchaton Oct 22, 2020
81ca911
Merge branch 'FEATURE/logging_in_train_callbacks' of https://github.c…
tchaton Oct 22, 2020
25242fb
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 22, 2020
1a6b172
check if metric and grad_norm_dic is defined
tchaton Oct 22, 2020
4690ff5
Merge branch 'FEATURE/logging_in_train_callbacks' of https://github.c…
tchaton Oct 22, 2020
c7c1e7d
resolve pep8
tchaton Oct 22, 2020
1dbe60c
resolve typo
tchaton Oct 22, 2020
54e2799
convert metris and grad_norm_dic to dict when None
tchaton Oct 22, 2020
8a8b54a
resolve pep8
tchaton Oct 22, 2020
43da42a
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 23, 2020
e1652bf
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 23, 2020
536f85f
Merge branch 'FEATURE/logging_in_train_callbacks' of https://github.c…
tchaton Oct 23, 2020
18247fa
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 23, 2020
1058e5e
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 23, 2020
6504731
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 24, 2020
e90f2c0
move files ar
tchaton Oct 25, 2020
292af7d
create connector_logger_utils
tchaton Oct 25, 2020
abdbd9f
resolve flake8
tchaton Oct 25, 2020
8d1c924
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 25, 2020
4aa2317
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 25, 2020
e16a358
Merge branch 'master' into FEATURE/logging_in_train_callbacks
tchaton Oct 26, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import collections
import copy
import inspect
import os
import re
import tempfile
from abc import ABC
from argparse import Namespace
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from pytorch_lightning import _logger as log
from pytorch_lightning.core.grads import GradInformation
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities import rank_zero_warn, AMPType
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities.parsing import (
AttributeDict,
collect_init_args,
get_init_args,
)
from pytorch_lightning.callbacks import Callback
from torch import ScriptModule, Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
Expand Down Expand Up @@ -111,6 +111,8 @@ def __init__(self, *args, **kwargs):
self._datamodule = None
self._results: Optional[Result] = None
self._current_fx_name = ''
self._current_hook_fx_name = ''
self._current_dataloader_idx = None

def optimizers(self):
opts = self.trainer.optimizers
Expand Down Expand Up @@ -244,6 +246,17 @@ def log(
on_step = self.__auto_choose_log_on_step(on_step)
on_epoch = self.__auto_choose_log_on_epoch(on_epoch)

if self._current_hook_fx_name != '':
self.trainer.logger_connector.callback_logging_validator\
.validate_callback_logging_arguments(self._current_hook_fx_name,
on_step=on_step,
on_epoch=on_epoch)

# make sure user doesn't introduce logic for multi-dataloaders
if "/dataloader_idx_" in name:
raise MisconfigurationException(
f"Logged key: {name} should not contain information about dataloader_idx.")

self._results.log(
name,
value,
Expand All @@ -257,7 +270,8 @@ def log(
enable_graph,
sync_dist,
sync_dist_op,
sync_dist_group
sync_dist_group,
self._current_dataloader_idx,
)

def log_dict(
Expand Down Expand Up @@ -950,7 +964,8 @@ def configure_optimizers(
- Single optimizer.
- List or Tuple - List of optimizers.
- Two lists - The first list has multiple optimizers, the second a list of LR schedulers (or lr_dict).
- Dictionary, with an 'optimizer' key, and (optionally) a 'lr_scheduler' key which value is a single LR scheduler or lr_dict.
- Dictionary, with an 'optimizer' key, and (optionally) a 'lr_scheduler' key which value is a single LR
scheduler or lr_dict.
- Tuple of dictionaries as described, with an optional 'frequency' key.
- None - Fit will run without any optimizer.

Expand Down Expand Up @@ -1278,11 +1293,11 @@ def tbptt_split_batch(self, batch, split_size):
batch_split = []
for i, x in enumerate(batch):
if isinstance(x, torch.Tensor):
split_x = x[:, t : t + split_size]
split_x = x[:, t: t + split_size]
elif isinstance(x, collections.Sequence):
split_x = [None] * len(x)
for batch_idx in range(len(x)):
split_x[batch_idx] = x[batch_idx][t : t + split_size]
split_x[batch_idx] = x[batch_idx][t: t + split_size]

batch_split.append(split_x)

Expand Down
19 changes: 14 additions & 5 deletions pytorch_lightning/core/step_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def log(
sync_dist: bool = False,
sync_dist_op: Union[Any, str] = 'mean',
sync_dist_group: Optional[Any] = None,
dataloader_idx: Optional[int] = None,
):
# no metrics should be logged with graphs
if not enable_graph and isinstance(value, torch.Tensor):
Expand All @@ -144,6 +145,7 @@ def log(

# set step version
step_name = f'{name}_step'

self.__set_meta(
step_name,
value,
Expand All @@ -154,12 +156,15 @@ def log(
reduce_fx=reduce_fx,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
forked=False
forked=False,
dataloader_idx=dataloader_idx,
)

self.__setitem__(step_name, value)

# set epoch version
epoch_name = f'{name}_epoch'

self.__set_meta(
epoch_name,
value,
Expand All @@ -170,7 +175,8 @@ def log(
reduce_fx=reduce_fx,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
forked=False
forked=False,
dataloader_idx=dataloader_idx,
)
self.__setitem__(epoch_name, value)

Expand All @@ -185,7 +191,8 @@ def log(
reduce_fx,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
forked=was_forked
forked=was_forked,
dataloader_idx=dataloader_idx,
)

# set the value
Expand All @@ -202,7 +209,8 @@ def __set_meta(
reduce_fx: Callable,
tbptt_pad_token: int,
tbptt_reduce_fx: Callable,
forked: bool
forked: bool,
dataloader_idx: Union[int, None]
):
# set the meta for the item
meta_value = value
Expand All @@ -215,7 +223,8 @@ def __set_meta(
value=meta_value,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
forked=forked
forked=forked,
dataloader_idx=dataloader_idx,
)

self['meta'][name] = meta
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from abc import ABC
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand Down Expand Up @@ -81,5 +82,4 @@ def configure_progress_bar(self, refresh_rate=1, process_position=0):
self.trainer.callbacks.append(progress_bar_callback)
else:
progress_bar_callback = None

return progress_bar_callback
23 changes: 21 additions & 2 deletions pytorch_lightning/trainer/connectors/logger_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.core.step_result import EvalResult, Result
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.trainer.connectors.logger_connector_utils import LoggingCallbackValidator, CacheInternalMetrics
from pprint import pprint
from typing import Iterable
from copy import deepcopy
from collections import ChainMap
from collections import defaultdict, ChainMap


class LoggerConnector:
Expand All @@ -33,6 +34,12 @@ def __init__(self, trainer):
self.logged_metrics = {}
self.progress_bar_metrics = {}
self.eval_loop_results = []
self.callback_logging_validator = LoggingCallbackValidator()
self.cache_internal_metrics = {"train": CacheInternalMetrics()}

def reset_cache_internal_metrics(self, stage):
assert stage in ["train"], f'Stage {stage} should be within ["train"]'
self.cache_internal_metrics[stage] = CacheInternalMetrics()

def on_trainer_init(self, logger, flush_logs_every_n_steps, log_every_n_steps):
# logging
Expand Down Expand Up @@ -76,7 +83,8 @@ def log_metrics(self, metrics, grad_norm_dic, step=None):
metrics.update(mem_map)

# add norms
metrics.update(grad_norm_dic)
if grad_norm_dic is not None:
metrics.update(grad_norm_dic)

# turn all tensors to scalars
scalar_metrics = self.trainer.metrics_to_scalars(metrics)
Expand Down Expand Up @@ -347,6 +355,13 @@ def log_train_epoch_end_metrics(self,
epoch_log_metrics.update(epoch_end_log_result.get_epoch_log_metrics())
epoch_progress_bar_metrics.update(epoch_end_log_result.get_epoch_pbar_metrics())

cache_internal_epoch_log_metrics = self.trainer.logger_connector\
.cache_internal_metrics["train"].get_as_dict("after_on_batch_end", "epoch_log_metrics")
epoch_log_metrics.update(cache_internal_epoch_log_metrics)

cache_internal_epoch_pbar_metrics = self.trainer.logger_connector\
.cache_internal_metrics["train"].get_as_dict("after_on_batch_end", "epoch_pbar_metrics")
epoch_progress_bar_metrics.update(cache_internal_epoch_pbar_metrics)
# TODO: deprecate 1.0
else:
out = self.__run_legacy_training_epoch_end(
Expand Down Expand Up @@ -532,6 +547,10 @@ def log_train_step_metrics(self, batch_output):
# logs user requested information to logger
metrics = batch_output.batch_log_metrics
grad_norm_dic = batch_output.grad_norm_dic
if metrics is None:
metrics = {}
if grad_norm_dic is None:
grad_norm_dic = {}
if len(metrics) > 0 or len(grad_norm_dic) > 0:
self.log_metrics(metrics, grad_norm_dic)
self.callback_metrics.update(metrics)
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from pytorch_lightning.trainer.connectors.logger_connector_utils.cache_metrics import CacheInternalMetrics
from pytorch_lightning.trainer.connectors.logger_connector_utils.callback_logging_validator import LoggingCallbackValidator
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from collections import defaultdict, ChainMap


class CacheInternalMetrics:
"""
This class is an helper to cache model._results logged values before / after entering batch loop.
As on every `run_training_batch`, we apply model._results = Result()
and therefore delete any previously logged values

before_on_batch_start is responsible to catch logging values from `on_start` to `on_batch_start`
after_on_batch_end is responsible to catch logging values from `on_batch_end` to `on_epoch_end`
"""

stages = ["before_on_batch_start", "after_on_batch_end"]

def __init__(self):
self._internal_dict = {stage: defaultdict(list) for stage in self.stages}

def append(self, stage: str, key: str, value) -> None:
assert stage in self.stages, f"Provided stage {stage} should be within {self.stages}"
self._internal_dict[stage][key].append(value)

def get_as_dict(self, stage, key):
_internal_metrics = self.get_as_list(stage, key)
return dict(ChainMap(*_internal_metrics))

def get_as_list(self, stage, key):
assert stage in self.stages, f"Provided stage {stage} should be within {self.stages}"
return self._internal_dict[stage][key]

def __repr__(self):
return self._internal_dict.__repr__()

def update(self, trainer, stage: str) -> None:
"""
This function is used to cache any logged information
between "on_train_start" to "on_train_epoch_start" callback hooks
"""
assert stage in self.stages, f"Provided stage {stage} should be within {self.stages}"
if not trainer.running_sanity_check:
model_ref = trainer.get_model()

# save epoch metrics
self.append(stage, "epoch_log_metrics", model_ref._results.get_epoch_log_metrics())
self.append(stage, "epoch_pbar_metrics", model_ref._results.get_epoch_pbar_metrics())

# save step/batch metrics
self.append(stage, "batch_log_metrics", model_ref._results.get_batch_log_metrics())
self.append(stage, "batch_pbar_metrics", model_ref._results.get_batch_pbar_metrics())
Loading