Profiler summary #1259

Merged (8 commits) on Mar 31, 2020
CHANGELOG.md (2 additions, 1 deletion)

@@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a check that stops the training when loss or weights contain `NaN` or `inf` values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))
- Updated references to `self.forward()` to use the `__call__` interface instead. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))
- Added support for `IterableDataset` when `val_check_interval=1.0` (default); this triggers validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283))
+- Added `summary` method to Profilers. ([#1259](https://github.com/PyTorchLightning/pytorch-lightning/pull/1259))
- Added informative errors if user-defined dataloader has zero length ([#1280](https://github.com/PyTorchLightning/pytorch-lightning/pull/1280))

### Changed
@@ -72,7 +73,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Support for user defined callbacks ([#889](https://github.com/PyTorchLightning/pytorch-lightning/pull/889) and [#950](https://github.com/PyTorchLightning/pytorch-lightning/pull/950))
- Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) ([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903))
- Added support for step-based learning rate scheduling ([#941](https://github.com/PyTorchLightning/pytorch-lightning/pull/941))
-- Added support for logging hparams as dict ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029))
+- Added support for logging `hparams` as dict ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029))
- Checkpoint and early stopping now work without val. step ([#1041](https://github.com/PyTorchLightning/pytorch-lightning/pull/1041))
- Support graceful training cleanup after Keyboard Interrupt ([#856](https://github.com/PyTorchLightning/pytorch-lightning/pull/856), [#1019](https://github.com/PyTorchLightning/pytorch-lightning/pull/1019))
- Added type hints for function arguments ([#912](https://github.com/PyTorchLightning/pytorch-lightning/pull/912))
pytorch_lightning/core/lightning.py (1 addition, 1 deletion)

@@ -20,7 +20,7 @@
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
    import torch_xla.core.xla_model as xm
pytorch_lightning/loggers/comet.py (1 addition, 1 deletion)

@@ -28,7 +28,7 @@

from pytorch_lightning import _logger as log
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_only
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException


class CometLogger(LightningLoggerBase):
pytorch_lightning/profiler/__init__.py (5 additions, 4 deletions)

@@ -3,7 +3,7 @@


Built-in checks
-----------------
+---------------

PyTorch Lightning supports profiling standard actions in the training loop out of the box, including:

@@ -20,7 +20,7 @@
- on_training_end

Enable simple profiling
--------------------------
+-----------------------

If you only wish to profile the standard actions, you can set `profiler=True` when constructing
your `Trainer` object.
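
For example, this is all it takes (a minimal sketch; `MyModel` is a placeholder for any `LightningModule` and is not part of this diff):

    from pytorch_lightning import Trainer

    # profiler=True is shorthand for the SimpleProfiler added in this PR;
    # its report is written via log.info() when training finishes.
    trainer = Trainer(profiler=True)
    trainer.fit(MyModel())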
@@ -113,10 +113,11 @@ def custom_processing_step(self, data):

"""

-from pytorch_lightning.profiler.profiler import Profiler, AdvancedProfiler, PassThroughProfiler
+from pytorch_lightning.profiler.profilers import SimpleProfiler, AdvancedProfiler, PassThroughProfiler, BaseProfiler

__all__ = [
-    'Profiler',
+    'BaseProfiler',
+    'SimpleProfiler',
    'AdvancedProfiler',
    'PassThroughProfiler',
]
pytorch_lightning/profiler/profilers.py (renamed from pytorch_lightning/profiler/profiler.py)

@@ -1,5 +1,6 @@
import cProfile
import io
+import os
import pstats
import time
from abc import ABC, abstractmethod
@@ -16,6 +17,18 @@ class BaseProfiler(ABC):
    If you wish to write a custom profiler, you should inherit from this class.
    """

+    def __init__(self, output_streams: list = None):
+        """
+        Params:
+            output_streams: callable or list of callables
+        """
+        if output_streams:
+            if not isinstance(output_streams, (list, tuple)):
+                output_streams = [output_streams]
+        else:
+            output_streams = []
+        self.write_streams = output_streams

    @abstractmethod
    def start(self, action_name: str) -> None:
        """Defines how to start recording an action."""
@@ -57,7 +70,12 @@ def profile_iterable(self, iterable, action_name: str) -> None:

    def describe(self) -> None:
        """Logs a profile report after the conclusion of the training run."""
-        pass
+        for write in self.write_streams:
+            write(self.summary())
+
+    @abstractmethod
+    def summary(self) -> str:
+        """Create profiler summary in text format."""


class PassThroughProfiler(BaseProfiler):
@@ -67,25 +85,39 @@ class PassThroughProfiler(BaseProfiler):
    """

    def __init__(self):
-        pass
+        super().__init__(output_streams=None)

    def start(self, action_name: str) -> None:
        pass

    def stop(self, action_name: str) -> None:
        pass

+    def summary(self) -> str:
+        return ""


-class Profiler(BaseProfiler):
+class SimpleProfiler(BaseProfiler):
    """
    This profiler simply records the duration of actions (in seconds) and reports
    the mean duration of each action and the total time spent over the entire training run.
    """

-    def __init__(self):
+    def __init__(self, output_filename: str = None):
+        """
+        Params:
+            output_filename (str): optionally save profile results to file instead of printing
+                to std out when training is finished.
+        """
        self.current_actions = {}
        self.recorded_durations = defaultdict(list)

+        self.output_fname = output_filename
+        self.output_file = open(self.output_fname, 'w') if self.output_fname else None
+
+        streaming_out = [self.output_file.write] if self.output_file else [log.info]
+        super().__init__(output_streams=streaming_out)

    def start(self, action_name: str) -> None:
        if action_name in self.current_actions:
            raise ValueError(
@@ -103,20 +135,31 @@ def stop(self, action_name: str) -> None:
        duration = end_time - start_time
        self.recorded_durations[action_name].append(duration)

-    def describe(self) -> None:
+    def summary(self) -> str:
        output_string = "\n\nProfiler Report\n"

        def log_row(action, mean, total):
-            return f"\n{action:<20s}\t| {mean:<15}\t| {total:<15}"
+            return f"{os.linesep}{action:<20s}\t| {mean:<15}\t| {total:<15}"

        output_string += log_row("Action", "Mean duration (s)", "Total time (s)")
-        output_string += f"\n{'-' * 65}"
+        output_string += f"{os.linesep}{'-' * 65}"
        for action, durations in self.recorded_durations.items():
            output_string += log_row(
                action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}",
            )
-        output_string += "\n"
-        log.info(output_string)
+        output_string += os.linesep
+        return output_string

+    def describe(self):
+        """Logs a profile report after the conclusion of the training run."""
+        super().describe()
+        if self.output_file:
+            self.output_file.flush()
+
+    def __del__(self):
+        """Close profiler's stream."""
+        if self.output_file:
+            self.output_file.close()
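
As a usage sketch (the filename is illustrative), the new `output_filename` argument redirects the report from `log.info` to a file, which the overridden `describe()` flushes at the end of the run:

    from pytorch_lightning import Trainer
    from pytorch_lightning.profiler import SimpleProfiler

    profiler = SimpleProfiler(output_filename="simple_profile.txt")
    trainer = Trainer(profiler=profiler)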


class AdvancedProfiler(BaseProfiler):
@@ -136,9 +179,14 @@ def __init__(self, output_filename: str = None, line_count_restriction: float =
            or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
        """
        self.profiled_actions = {}
-        self.output_filename = output_filename
        self.line_count_restriction = line_count_restriction

+        self.output_fname = output_filename
+        self.output_file = open(self.output_fname, 'w') if self.output_fname else None
+
+        streaming_out = [self.output_file.write] if self.output_file else [log.info]
+        super().__init__(output_streams=streaming_out)

    def start(self, action_name: str) -> None:
        if action_name not in self.profiled_actions:
            self.profiled_actions[action_name] = cProfile.Profile()
@@ -152,22 +200,28 @@ def stop(self, action_name: str) -> None:
            )
        pr.disable()

-    def describe(self) -> None:
-        self.recorded_stats = {}
+    def summary(self) -> str:
+        recorded_stats = {}
        for action_name, pr in self.profiled_actions.items():
            s = io.StringIO()
            ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats('cumulative')
            ps.print_stats(self.line_count_restriction)
-            self.recorded_stats[action_name] = s.getvalue()
-        if self.output_filename is not None:
-            # save to file
-            with open(self.output_filename, "w") as f:
-                for action, stats in self.recorded_stats.items():
-                    f.write(f"Profile stats for: {action}")
-                    f.write(stats)
-        else:
-            # log to standard out
-            output_string = "\nProfiler Report\n"
-            for action, stats in self.recorded_stats.items():
-                output_string += f"\nProfile stats for: {action}\n{stats}"
-            log.info(output_string)
+            recorded_stats[action_name] = s.getvalue()
+
+        # log to standard out
+        output_string = f"{os.linesep}Profiler Report{os.linesep}"
+        for action, stats in recorded_stats.items():
+            output_string += f"{os.linesep}Profile stats for: {action}{os.linesep}{stats}"
+
+        return output_string
+
+    def describe(self):
+        """Logs a profile report after the conclusion of the training run."""
+        super().describe()
+        if self.output_file:
+            self.output_file.flush()
+
+    def __del__(self):
+        """Close profiler's stream."""
+        if self.output_file:
+            self.output_file.close()
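
The advanced profiler is configured the same way; a sketch under the same assumptions, where `line_count_restriction=0.05` keeps roughly the top 5% of lines of each action's cProfile report:

    from pytorch_lightning import Trainer
    from pytorch_lightning.profiler import AdvancedProfiler

    profiler = AdvancedProfiler(output_filename="advanced_profile.txt",
                                line_count_restriction=0.05)
    trainer = Trainer(profiler=profiler)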
pytorch_lightning/trainer/data_loading.py (1 addition, 1 deletion)

@@ -6,7 +6,7 @@
from torch.utils.data.distributed import DistributedSampler

from pytorch_lightning.core import LightningModule
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
    from apex import amp
pytorch_lightning/trainer/distrib_data_parallel.py (1 addition, 1 deletion)

@@ -122,7 +122,7 @@ def train_fx(trial_hparams, cluster_manager, _):
import torch
from pytorch_lightning import _logger as log
from pytorch_lightning.loggers import LightningLoggerBase
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
    from apex import amp
pytorch_lightning/trainer/distrib_parts.py (1 addition, 1 deletion)

@@ -344,7 +344,7 @@
    LightningDistributedDataParallel,
    LightningDataParallel,
)
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
    from apex import amp
pytorch_lightning/trainer/evaluation_loop.py (1 addition, 1 deletion)

@@ -135,7 +135,7 @@

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
    import torch_xla.distributed.parallel_loader as xla_pl
pytorch_lightning/trainer/trainer.py (5 additions, 6 deletions)

@@ -18,8 +18,7 @@
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
-from pytorch_lightning.profiler import Profiler, PassThroughProfiler
-from pytorch_lightning.profiler.profiler import BaseProfiler
+from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler, BaseProfiler
from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin
from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
@@ -33,7 +32,7 @@
from pytorch_lightning.trainer.training_io import TrainerIOMixin
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.trainer.supporters import TensorRunningMean

try:
@@ -364,7 +363,7 @@ def __init__(

        # configure profiler
        if profiler is True:
-            profiler = Profiler()
+            profiler = SimpleProfiler()
        self.profiler = profiler or PassThroughProfiler()

        # configure early stop callback
@@ -490,10 +489,10 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
         ('print_nan_grads', (<class 'bool'>,), False),
         ('process_position', (<class 'int'>,), 0),
         ('profiler',
-         (<class 'pytorch_lightning.profiler.profiler.BaseProfiler'>,
+         (<class 'pytorch_lightning.profiler.profilers.BaseProfiler'>,
           <class 'NoneType'>),
          None),
         ...
"""
trainer_default_params = inspect.signature(cls).parameters
name_type_default = []
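
Taken together, after this PR the `profiler` argument accepts a boolean, a profiler instance, or nothing; a brief hedged sketch:

    from pytorch_lightning import Trainer
    from pytorch_lightning.profiler import AdvancedProfiler

    Trainer(profiler=True)                # shorthand for SimpleProfiler()
    Trainer(profiler=AdvancedProfiler())  # any BaseProfiler subclass instance
    Trainer()                             # default: PassThroughProfiler (no-op)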
pytorch_lightning/trainer/training_loop.py (1 addition, 1 deletion)

@@ -145,7 +145,7 @@ def training_step(self, batch, batch_idx):
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.trainer.supporters import TensorRunningMean

try:
tests/loggers/test_comet.py (1 addition, 1 deletion)

@@ -8,7 +8,7 @@
import tests.base.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CometLogger
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import LightningTestModel


tests/models/test_amp.py (1 addition, 1 deletion)

@@ -5,7 +5,7 @@

import tests.base.utils as tutils
from pytorch_lightning import Trainer
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import (
    LightningTestModel,
)
tests/models/test_gpu.py (1 addition, 1 deletion)

@@ -11,7 +11,7 @@
    parse_gpu_ids,
    determine_root_gpu_device,
)
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import LightningTestModel

PRETEND_N_OF_GPUS = 16
tests/models/test_restore.py (1 addition, 1 deletion)

@@ -8,7 +8,7 @@
import tests.base.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import (
    LightningTestModel,
    LightningTestModelWithoutHyperparametersArg,
tests/test_deprecated.py (2 additions)

@@ -57,6 +57,8 @@ def test_tbd_remove_in_v0_9_0_module_imports():
    from pytorch_lightning.logging.test_tube import TestTubeLogger  # noqa: F402
    from pytorch_lightning.logging.wandb import WandbLogger  # noqa: F402

+    from pytorch_lightning.profiler import SimpleProfiler, AdvancedProfiler  # noqa: F402


class ModelVer0_6(LightTrainDataloader, LightEmptyTestStep, TestModelBase):
