[FEAT] logging refactors 1/n (#4439)
* introducing new logging object

* typo

* typo

* Update pytorch_lightning/trainer/logging.py

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

* Update pytorch_lightning/trainer/logging.py

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

* update on comments

* update on comments

* add more docstrings

* Update pytorch_lightning/core/lightning.py

Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>

* resolve on comments

* solve pyright

* Update pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

* update on comments

* Update pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py

Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>

* update on comments

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
3 people authored Nov 2, 2020
1 parent 19187d3 commit ac3f739
Showing 11 changed files with 1,229 additions and 69 deletions.
29 changes: 22 additions & 7 deletions pytorch_lightning/core/lightning.py
@@ -11,13 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import collections
import copy
import inspect
import os
import re
import tempfile
from abc import ABC
from argparse import Namespace
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, Mapping
@@ -28,16 +27,17 @@
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities import rank_zero_warn, AMPType
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities.parsing import (
AttributeDict,
collect_init_args,
get_init_args,
)
from pytorch_lightning.callbacks import Callback
from torch import ScriptModule, Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
@@ -111,6 +111,8 @@ def __init__(self, *args, **kwargs):
self._datamodule = None
self._results: Optional[Result] = None
self._current_fx_name = ''
self._current_hook_fx_name = None
self._current_dataloader_idx = None

def optimizers(self):
opts = self.trainer.optimizers
@@ -244,6 +246,18 @@ def log(
on_step = self.__auto_choose_log_on_step(on_step)
on_epoch = self.__auto_choose_log_on_epoch(on_epoch)

if self._current_hook_fx_name is not None:
self.trainer.logger_connector.check_logging_in_callbacks(
self._current_hook_fx_name,
on_step=on_step,
on_epoch=on_epoch
)

# make sure user doesn't introduce logic for multi-dataloaders
if "/dataloader_idx_" in name:
raise MisconfigurationException(
f"Logged key: {name} should not contain information about dataloader_idx.")

self._results.log(
name,
value,
@@ -257,7 +271,8 @@ def log(
enable_graph,
sync_dist,
sync_dist_op,
sync_dist_group
sync_dist_group,
self._current_dataloader_idx,
)

def log_dict(
Expand Down Expand Up @@ -1277,11 +1292,11 @@ def tbptt_split_batch(self, batch, split_size):
batch_split = []
for i, x in enumerate(batch):
if isinstance(x, torch.Tensor):
split_x = x[:, t : t + split_size]
split_x = x[:, t: t + split_size]
elif isinstance(x, collections.Sequence):
split_x = [None] * len(x)
for batch_idx in range(len(x)):
split_x[batch_idx] = x[batch_idx][t : t + split_size]
split_x[batch_idx] = x[batch_idx][t: t + split_size]

batch_split.append(split_x)

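A minimal sketch (not part of the diff) of what the new guard in `LightningModule.log` means for user code with multiple validation dataloaders; the module and metric below are purely illustrative:

```python
import torch
from pytorch_lightning import LightningModule


class MultiLoaderModel(LightningModule):
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        loss = torch.tensor(0.0)  # placeholder value for illustration

        # OK: the active dataloader index is tracked internally via
        # `self._current_dataloader_idx` and forwarded to `Result.log`.
        self.log("val_loss", loss)

        # Not OK after this change: a key that already encodes the dataloader
        # index raises a MisconfigurationException.
        # self.log(f"val_loss/dataloader_idx_{dataloader_idx}", loss)
```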
77 changes: 56 additions & 21 deletions pytorch_lightning/core/step_result.py
@@ -124,6 +124,7 @@ def log(
sync_dist: bool = False,
sync_dist_op: Union[Any, str] = 'mean',
sync_dist_group: Optional[Any] = None,
dataloader_idx: Optional[int] = None,
):
# no metrics should be logged with graphs
if not enable_graph and isinstance(value, torch.Tensor):
@@ -144,6 +145,7 @@

# set step version
step_name = f'{name}_step'

self.__set_meta(
step_name,
value,
@@ -154,12 +156,15 @@
reduce_fx=reduce_fx,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
forked=False
forked=False,
dataloader_idx=dataloader_idx,
)

self.__setitem__(step_name, value)

# set epoch version
epoch_name = f'{name}_epoch'

self.__set_meta(
epoch_name,
value,
@@ -170,7 +175,8 @@
reduce_fx=reduce_fx,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
forked=False
forked=False,
dataloader_idx=dataloader_idx,
)
self.__setitem__(epoch_name, value)

Expand All @@ -185,7 +191,8 @@ def log(
reduce_fx,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
forked=was_forked
forked=was_forked,
dataloader_idx=dataloader_idx,
)

# set the value
@@ -202,7 +209,8 @@ def __set_meta(
reduce_fx: Callable,
tbptt_pad_token: int,
tbptt_reduce_fx: Callable,
forked: bool
forked: bool,
dataloader_idx: Union[int, None]
):
# set the meta for the item
meta_value = value
@@ -215,7 +223,8 @@
value=meta_value,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
forked=forked
forked=forked,
dataloader_idx=dataloader_idx,
)

self['meta'][name] = meta
Expand All @@ -225,13 +234,22 @@ def __set_meta(
_internal['_reduce_on_epoch'] = max(_internal['_reduce_on_epoch'], on_epoch)

def track_batch_size(self, batch):
batch_size = Result.extract_batch_size(batch)
Result.attach_batch_size(batch_size, self)

@staticmethod
def extract_batch_size(batch):
try:
batch_size = Result.unpack_batch_size(batch)
except RecursionError as re:
batch_size = 1
return batch_size

meta = self['meta']
meta['_internal']['batch_sizes'].append(batch_size)
@staticmethod
def attach_batch_size(batch_size: Union[int, None], result: 'Result') -> None:
if batch_size is not None:
meta = result['meta']
meta['_internal']['batch_sizes'].append(batch_size)

def get_batch_sizes(self):
meta = self['meta']
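`track_batch_size` is split into two static helpers; a rough sketch of using them directly (the batch and `Result` instance are illustrative):

```python
import torch
from pytorch_lightning.core.step_result import Result

batch = torch.randn(8, 32)  # any structure Result.unpack_batch_size understands

result = Result()
# extract_batch_size falls back to a batch size of 1 on RecursionError;
# attach_batch_size appends the size to the result's internal meta, where
# get_batch_sizes() later reads it for the weighted epoch reduction.
batch_size = Result.extract_batch_size(batch)   # -> 8
Result.attach_batch_size(batch_size, result)
```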
@@ -242,7 +260,12 @@ def get_callback_metrics(self) -> dict:

return result

def get_batch_log_metrics(self, include_forked_originals=True) -> dict:
def _add_dataloader_idx(self, k: str, dataloader_idx: Union[int, None], add_dataloader_idx: bool) -> str:
if dataloader_idx is not None and add_dataloader_idx:
return f"{k}/dataloader_idx_{dataloader_idx}"
return k

def get_batch_log_metrics(self, include_forked_originals=True, add_dataloader_idx=False) -> dict:
"""
Gets the metrics to log at the end of the batch step
@@ -257,15 +280,17 @@ def get_batch_log_metrics(self, include_forked_originals=True) -> dict:
if options['forked'] and not include_forked_originals:
continue

dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)

if options['logger'] and options['on_step']:
if isinstance(self[k], Metric):
result[k] = self[k]._forward_cache.detach()
result[dl_key] = self[k]._forward_cache.detach()
else:
result[k] = self[k]
result[dl_key] = self[k]

return result

def get_epoch_log_metrics(self) -> dict:
def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict:
"""
Gets the metrics to log at the end of epoch
"""
@@ -279,19 +304,21 @@ def get_epoch_log_metrics(self) -> dict:
if options['forked']:
continue

dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)

if options['logger'] and options['on_epoch']:
if isinstance(self[k], Metric):
result[k] = self[k].compute().detach()
result[dl_key] = self[k].compute().detach()
else:
result[k] = self[k]
result[dl_key] = self[k]

if k in self and not options['on_epoch'] and isinstance(self[k], Metric):
# compute metric on epoch anyway so state does not accumulate
self[k].compute()

return result

def get_epoch_pbar_metrics(self):
def get_epoch_pbar_metrics(self, add_dataloader_idx=False):
"""
Gets the metrics to log at the end of epoch
"""
@@ -305,19 +332,21 @@ def get_epoch_pbar_metrics(self):
if options['forked']:
continue

dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)

if options['prog_bar'] and options['on_epoch']:
if isinstance(self[k], Metric):
result[k] = self[k].compute().detach()
result[dl_key] = self[k].compute().detach()
else:
result[k] = self[k]
result[dl_key] = self[k]

if k in self and not options['on_epoch'] and isinstance(self[k], Metric):
# compute metric on epoch anyway so state does not accumulate
self[k].compute()

return result

def get_forked_metrics(self):
def get_forked_metrics(self, add_dataloader_idx=False):
"""
Gets the metrics to log at the end of epoch
"""
@@ -328,12 +357,14 @@ def get_forked_metrics(self):
if k == '_internal':
continue

dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)

if options['forked']:
result[k] = self[k]
result[dl_key] = self[k]

return result

def get_batch_pbar_metrics(self, include_forked_originals=True):
def get_batch_pbar_metrics(self, include_forked_originals=True, add_dataloader_idx=False):
"""
Gets the metrics to log at the end of the batch step
"""
@@ -347,11 +378,13 @@ def get_batch_pbar_metrics(self, include_forked_originals=True):
if options['forked'] and not include_forked_originals:
continue

dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)

if options['prog_bar'] and options['on_step']:
if isinstance(self[k], Metric):
result[k] = self[k]._forward_cache
result[dl_key] = self[k]._forward_cache
else:
result[k] = self[k]
result[dl_key] = self[k]

return result

@@ -473,6 +506,8 @@ def reduce_on_epoch_end(cls, outputs):
if option['on_epoch']:
fx = option['reduce_fx']
if fx == torch.mean:
if isinstance(result[k], list):
result[k] = torch.tensor(result[k]).float()
try:
reduced_val = weighted_mean(result[k], batch_sizes)
except Exception as e:
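A hedged sketch of the key suffixing that the new `add_dataloader_idx` flag enables in the metric getters; the metric name, value, and index are illustrative:

```python
import torch
from pytorch_lightning.core.step_result import Result

result = Result()
# The dataloader index now travels in the metric's meta entry, not in the key.
result.log("val_acc", torch.tensor(0.75),
           prog_bar=False, logger=True, on_step=True, on_epoch=False,
           dataloader_idx=1)

result.get_batch_log_metrics()                         # {'val_acc': tensor(0.7500)}
result.get_batch_log_metrics(add_dataloader_idx=True)  # {'val_acc/dataloader_idx_1': tensor(0.7500)}
```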
@@ -0,0 +1 @@
from pytorch_lightning.trainer.connectors.logger_connector.logger_connector import LoggerConnector
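The single line added above presumably lives in the new package's `__init__.py` (the file name is not shown in this excerpt), so the established import path keeps resolving:

```python
# Works via the re-export added above.
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
```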
