Skip to content

Commit

Permalink
Enable dumping raw prof files in AdvancedProfiler (#19703)
Browse files Browse the repository at this point in the history
Co-authored-by: Alexander Jipa <azzhipa@amazon.com>
  • Loading branch information
clumsy and azzhipa authored Jul 15, 2024
1 parent 2dc9c3d commit 74470a6
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Made saving non-distributed checkpoints fully atomic ([#20011](https://github.com/Lightning-AI/pytorch-lightning/pull/20011))

- Added `dump_stats` flag to `AdvancedProfiler` ([#19703](https://github.com/Lightning-AI/pytorch-lightning/issues/19703))

-

### Changed
Expand Down
25 changes: 25 additions & 0 deletions src/lightning/pytorch/profilers/advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@
import cProfile
import io
import logging
import os
import pstats
import tempfile
from pathlib import Path
from typing import Dict, Optional, Tuple, Union

from typing_extensions import override

from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.pytorch.profilers.profiler import Profiler
from lightning.pytorch.utilities.rank_zero import rank_zero_only

log = logging.getLogger(__name__)

Expand All @@ -40,6 +44,7 @@ def __init__(
dirpath: Optional[Union[str, Path]] = None,
filename: Optional[str] = None,
line_count_restriction: float = 1.0,
dump_stats: bool = False,
) -> None:
"""
Args:
Expand All @@ -54,13 +59,16 @@ def __init__(
reported for each action. either an integer (to select a count of lines),
or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
dump_stats: Whether to save raw profiler results. When ``True`` then ``dirpath`` must be provided.
Raises:
ValueError:
If you attempt to stop recording an action which was never started.
"""
super().__init__(dirpath=dirpath, filename=filename)
self.profiled_actions: Dict[str, cProfile.Profile] = {}
self.line_count_restriction = line_count_restriction
self.dump_stats = dump_stats

@override
def start(self, action_name: str) -> None:
Expand All @@ -75,10 +83,27 @@ def stop(self, action_name: str) -> None:
raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.")
pr.disable()

def _dump_stats(self, action_name: str, profile: cProfile.Profile) -> None:
assert self.dirpath
dst_filepath = os.path.join(self.dirpath, self._prepare_filename(action_name=action_name, extension=".prof"))
dst_fs = get_filesystem(dst_filepath)
dst_fs.mkdirs(self.dirpath, exist_ok=True)
# temporarily save to local since pstats can only dump into a local file
with tempfile.TemporaryDirectory(
prefix="test", suffix=str(rank_zero_only.rank), dir=os.getcwd()
) as tmp_dir, dst_fs.open(dst_filepath, "wb") as dst_file:
src_filepath = os.path.join(tmp_dir, "tmp.prof")
profile.dump_stats(src_filepath)
src_fs = get_filesystem(src_filepath)
with src_fs.open(src_filepath, "rb") as src_file:
dst_file.write(src_file.read())

@override
def summary(self) -> str:
recorded_stats = {}
for action_name, pr in self.profiled_actions.items():
if self.dump_stats:
self._dump_stats(action_name, pr)
s = io.StringIO()
ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats("cumulative")
ps.print_stats(self.line_count_restriction)
Expand Down
13 changes: 13 additions & 0 deletions tests/tests_pytorch/profilers/test_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,19 @@ def test_advanced_profiler_describe(tmp_path, advanced_profiler):
assert len(data) > 0


def test_advanced_profiler_dump_states(tmp_path):
advanced_profiler = AdvancedProfiler(dirpath=tmp_path, dump_stats=True)
"""Ensure the profiler dump stats during summary."""
# record at least one event
with advanced_profiler.profile(action_name := "test"):
pass
# dump_stats to file
advanced_profiler.describe()
path = advanced_profiler.dirpath / f"{action_name}.prof"
data = path.read_bytes()
assert len(data) > 0


def test_advanced_profiler_value_errors(advanced_profiler):
"""Ensure errors are raised where expected."""
action = "test"
Expand Down

0 comments on commit 74470a6

Please sign in to comment.