
CUDA Memory Profile Analyzer #9860

Closed
wants to merge 13 commits into from
195 changes: 161 additions & 34 deletions nemo/core/classes/modelPT.py
@@ -21,6 +21,7 @@
from os import path
from pathlib import Path
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
from packaging import version

import hydra
import torch
@@ -49,6 +50,7 @@
from nemo.utils.debug_hook import register_debug_hooks
from nemo.utils.exceptions import NeMoBaseException
from nemo.utils.get_rank import get_rank, is_global_rank_zero
from nemo.utils.memory_profile_analyzer import peak_memory_analysis

__all__ = ['ModelPT']

@@ -204,6 +206,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

# Setup nsys profiling if it has been enabled in the model config
self._setup_profiling()
# Keep our own accurate batch index: the `batch_idx` passed to `on_train_batch_start` and `on_train_batch_end` has a known bug.
self._real_batch_idx = 0
Collaborator
Can this be avoided by using self.trainer.global_step instead? (trainer is never None during a training-step call anyway.)
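A rough sketch of what that suggestion could look like (hypothetical, not part of this PR; it reuses the `_nsys_profile_*` attributes from this diff and Lightning's built-in `trainer.global_step` counter):

```python
# Hypothetical sketch: key profiling off Lightning's own step counter instead of
# a hand-maintained self._real_batch_idx. Attribute names are taken from this diff.
def on_train_batch_start(self, batch, batch_idx, unused=0):
    step = self.trainer.global_step  # tracked by Lightning; no manual bookkeeping needed
    if (
        self._nsys_profile_enabled
        and not self._nsys_profile_started
        and step >= self._nsys_profile_start_step
        and get_rank() in self._nsys_profile_ranks
    ):
        torch.cuda.cudart().cudaProfilerStart()
        self._nsys_profile_started = True
```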


# A flag for the profile generation
self._nsys_profile_started = False
@@ -1747,7 +1751,9 @@ def _setup_profiling(self):
end_step: 10 # Global batch to end profiling
rank: 0 # Global rank ID to profile
output_path: None # Path to store the profile output file
Minimum Pytorch version for memory_profile: 2.2.0
"""
min_pt_version = "2.2.0"
if self.cfg.get('nsys_profile', None) is not None:
if self.cfg.nsys_profile.get('enabled', False):
# Nsys profiling options
@@ -1776,36 +1782,87 @@ def _setup_profiling(self):

if self.cfg.get('memory_profile', None) is not None:
    if self.cfg.memory_profile.get('enabled', False):
        # Guard memory-profile availability by checking the PyTorch version; it must be at least min_pt_version.
        if version.parse(torch.__version__) < version.parse(min_pt_version):
            raise ValueError(
                f"Minimum Pytorch version for memory_profile is {min_pt_version}. Found: {torch.__version__}"
            )

        # CUDA memory profiling options
        self._memory_profile_enabled = True
        self._memory_profile_start_step = self.cfg.memory_profile.get('start_step', 0)
        self._memory_profile_end_step = self.cfg.memory_profile.get('end_step', 0)
        self._memory_profile_rank = self.cfg.memory_profile.get('rank', 0)
        self._memory_profile_output_path = self.cfg.memory_profile.get('output_path', None)
        self._memory_profile_snapshot_file_activation = (
            f'{self._memory_profile_output_path}/memory_profile_rank{self._memory_profile_rank}_act.pickle'
        )
        self._memory_profile_snapshot_file_weight = (
            f'{self._memory_profile_output_path}/memory_profile_rank{self._memory_profile_rank}_weight.pickle'
        )
        self._memory_profile_snapshot_file_oom = (
            f'{self._memory_profile_output_path}/memory_profile_rank{self._memory_profile_rank}_oom.pickle'
        )
        # CUDA memory profiling options: analysis
        self._memory_profile_analysis_enabled = self.cfg.memory_profile.get('analysis_enabled', False)
        self._memory_profile_analysis_path = os.path.join(self._memory_profile_output_path, "analysis")

        logging.info(
            f"====== Rank[{self._memory_profile_rank}], Initialization. Start CUDA memory profiling: Weight ======"
        )
        torch.cuda.memory._record_memory_history()
Collaborator
Does this need a specific version of PyTorch? If so, please guard it so that it doesn't run on lower PyTorch versions.
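One possible shape for such a guard, as a sketch (the helper name is made up; the private `torch.cuda.memory` hooks are the ones this diff already relies on):

```python
from packaging import version

import torch

MIN_PT_VERSION = "2.2.0"  # minimum version stated in this PR's docstring


def memory_history_supported() -> bool:
    """Hypothetical helper: only enable memory-history recording on new-enough PyTorch."""
    if version.parse(torch.__version__) < version.parse(MIN_PT_VERSION):
        return False
    # The recording/snapshot hooks are private APIs, so also verify they exist.
    return hasattr(torch.cuda.memory, "_record_memory_history") and hasattr(
        torch.cuda.memory, "_dump_snapshot"
    )
```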


        torch._C._cuda_attach_out_of_memory_observer(self._oom_observer)

        if type(self._memory_profile_start_step) == int:
            logging.info(
                f'CUDA memory profiling (Activation) setup with start_step: {self._memory_profile_start_step}'
            )
        else:
            raise ValueError(
                f'CUDA memory start_step must be of type int. Found: {type(self._memory_profile_start_step)}'
            )

        if type(self._memory_profile_end_step) == int:
            logging.info(
                f'CUDA memory profiling (Activation) setup with end_step: {self._memory_profile_end_step}'
            )
        else:
            raise ValueError(
                f'CUDA memory end_step must be of type int. Found: {type(self._memory_profile_end_step)}'
            )

        if self._memory_profile_end_step < self._memory_profile_start_step:
            raise ValueError(
                'CUDA memory (Activation) end_step must be greater than or equal to memory start_step'
            )

        if self._memory_profile_output_path is None or not os.path.isdir(self._memory_profile_output_path):
            raise ValueError(
                f'Memory profile output path ({self._memory_profile_output_path}) is not set or does not exist.'
            )

def _oom_observer(self, device, alloc, device_alloc, device_free, *args, **kwargs):
if get_rank() == self._memory_profile_rank:
logging.info(
f"====== Rank[{self._memory_profile_rank}]. OOM Profile. End CUDA memory profiling: Out Of Memory. Snapshot saved in {self._memory_profile_snapshot_file_oom} ======"
)
torch.cuda.memory._dump_snapshot(f"{self._memory_profile_snapshot_file_oom}")
Collaborator
Guard if necessary
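A guarded variant of the observer might look like the sketch below (an assumption, not the PR's code): snapshot-dump failures are logged rather than raised, so they cannot mask the original out-of-memory error.

```python
def _oom_observer(self, device, alloc, device_alloc, device_free, *args, **kwargs):
    # Hypothetical guarded variant: this callback runs on the allocator's failure
    # path, so a failing snapshot dump is logged instead of re-raised.
    if get_rank() != self._memory_profile_rank:
        return
    try:
        torch.cuda.memory._dump_snapshot(self._memory_profile_snapshot_file_oom)
    except Exception as err:
        logging.error(f"Could not dump OOM memory snapshot: {err}")
```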

# If the snapshot exists, call the peak-memory analyzer and export the CSV file.
# This can take a while after the error appears, since the analysis runs here synchronously.
if self._memory_profile_analysis_enabled:
if os.path.exists(self._memory_profile_snapshot_file_oom):
logging.info(f"===== Memory Profile Analysis: OOM ======")
peak_memory_analysis(
self._memory_profile_snapshot_file_oom,
self._memory_profile_analysis_path,
'oom',
self._memory_profile_rank,
)
else:
raise Exception(f"Snapshot file not found: {self._memory_profile_snapshot_file_oom}")
Member
Maybe move this after line 1833 torch.cuda.memory._dump_snapshot?


def on_train_start(self):
"""PyTorch Lightning hook:
@@ -1831,24 +1888,74 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, unused: int = 0) -> O
https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-batch-start
We use it here to enable nsys profiling and dynamic freezing.
"""

# nsys profiling
if self.device.type == 'cuda':
if hasattr(self, '_nsys_profile_enabled'):
if self._nsys_profile_enabled and not self._nsys_profile_started:
if (
self._real_batch_idx >= self._nsys_profile_start_step
and get_rank() in self._nsys_profile_ranks
):
logging.info("====== Start nsys profiling ======")
torch.cuda.cudart().cudaProfilerStart()
if self._nsys_profile_gen_shape:
torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()
self._nsys_profile_started = True

if hasattr(self, '_memory_profile_enabled'):
if self._memory_profile_enabled:
if (
(self._real_batch_idx == 0)
and (get_rank() == self._memory_profile_rank)
): # before batch-0 start
try:
logging.info(
f"====== Rank[{self._memory_profile_rank}], Batch[{batch_idx}]. End CUDA memory profiling: Weight. Snapshot saved in {self._memory_profile_snapshot_file_weight} ======"
)
torch.cuda.memory._dump_snapshot(f"{self._memory_profile_snapshot_file_weight}")
# torch.cuda.memory._record_memory_history(enabled=None)
except Exception as e:
logging.error(f"Failed to capture memory snapshot {e}")
return

# Call the analysis function
if self._memory_profile_analysis_enabled:
Member
Do you need this `if`? I would assume `_memory_profile_analysis_enabled` does not change value, and on line 1880 you already check whether it's true or not.

if not os.path.exists(self._memory_profile_analysis_path):
os.makedirs(self._memory_profile_analysis_path)
if os.path.exists(self._memory_profile_snapshot_file_weight):
logging.info(f"====== Memory Profile Analysis: Weight ======")
peak_memory_analysis(
self._memory_profile_snapshot_file_weight,
self._memory_profile_analysis_path,
'weight',
self._memory_profile_rank,
)
else:
raise Exception(
f"Snapshot file not found: {self._memory_profile_snapshot_file_weight}"
)

if self._memory_profile_enabled and not self._memory_profile_started:
if (
self._real_batch_idx == self._memory_profile_start_step
and get_rank() == self._memory_profile_rank
):
logging.info(
f"====== Rank[{self._memory_profile_rank}], Batch[{batch_idx}]. Start CUDA memory profiling: Activation ======"
)
# Record the current allocated_memory before this training batch.
self._memory_profile_allocated = torch.cuda.memory_allocated() / (
1024 * 1024 * 1024
) # should be weight_memory. # in GB.
# Restart recording memory history
torch.cuda.memory._record_memory_history(enabled=None)
torch.cuda.memory._record_memory_history(
max_entries=30000000
) # we set a very large max_entries to avoid the truncation that we don't want. Normally a few global batches won't exceed this restriction. This is only to avoid the extremely large file.
self._memory_profile_started = True
logging.info(
f"Before recording activation memory snapshot, allocated memory: {self._memory_profile_allocated} GB"
)

# dynamic freezing
if hasattr(self, '_freeze_cfg') and self._freeze_cfg is not None:
@@ -1877,24 +1984,44 @@ def on_train_batch_end(self, outputs, batch: Any, batch_idx: int, unused: int =
https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-batch-end
We use it here to enable nsys profiling.
"""

if self.device.type == 'cuda':
if hasattr(self, '_nsys_profile_enabled'):
if self._nsys_profile_enabled and not self._nsys_profile_complete:
if self._real_batch_idx >= self._nsys_profile_end_step and get_rank() in self._nsys_profile_ranks:
logging.info("====== End nsys profiling ======")
torch.cuda.cudart().cudaProfilerStop()
self._nsys_profile_complete = True

if hasattr(self, '_memory_profile_enabled'):
if self._memory_profile_enabled and not self._memory_profile_complete:
if (
self._real_batch_idx == self._memory_profile_end_step
and get_rank() == self._memory_profile_rank
):
logging.info(
f"====== Rank[{self._memory_profile_rank}], Batch[{batch_idx}]. End CUDA memory profiling: Activation. Snapshot saved in {self._memory_profile_snapshot_file_activation} ======"
)
torch.cuda.memory._dump_snapshot(self._memory_profile_snapshot_file_activation)
torch.cuda.memory._record_memory_history(enabled=None)
self._memory_profile_complete = True
# Call the analysis function
if self._memory_profile_analysis_enabled and self._memory_profile_complete:
Member
Same as previously: shouldn't `self._memory_profile_analysis_enabled` already be true here, due to line 1890?

if not os.path.exists(self._memory_profile_analysis_path):
os.makedirs(self._memory_profile_analysis_path)
if os.path.exists(self._memory_profile_snapshot_file_activation):
logging.info(f"====== Memory Profile Analysis: Activation ======")
peak_memory_analysis(
self._memory_profile_snapshot_file_activation,
self._memory_profile_analysis_path,
'act',
self._memory_profile_rank,
)
else:
raise Exception(
f"Snapshot file not found: {self._memory_profile_snapshot_file_activation}"
)
# increase the batch_idx
self._real_batch_idx += 1

def _cleanup_on_execution_end(self):
"""
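For reference, a minimal sketch of how the options consumed by `_setup_profiling` above might be enabled. The field names mirror the `cfg.memory_profile.get(...)` calls in this diff; the concrete values and the use of OmegaConf here are assumptions.

```python
from omegaconf import OmegaConf

# Field names mirror the cfg.memory_profile.get(...) calls in this diff.
cfg = OmegaConf.create(
    {
        "memory_profile": {
            "enabled": True,
            "start_step": 2,  # global batch at which activation recording starts
            "end_step": 4,  # global batch at which the activation snapshot is dumped
            "rank": 0,  # global rank whose memory history is recorded
            "output_path": "/tmp/mem_profile",  # must be an existing directory
            "analysis_enabled": True,  # also run peak_memory_analysis on each snapshot
        }
    }
)
```

With `analysis_enabled` set, the CSV output of `peak_memory_analysis` lands under `<output_path>/analysis`, as set up in `_setup_profiling`.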