CUDA Memory Profile Analyzer #9860
@@ -21,6 +21,7 @@
from os import path
from pathlib import Path
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
from packaging import version

import hydra
import torch

@@ -49,6 +50,7 @@
from nemo.utils.debug_hook import register_debug_hooks
from nemo.utils.exceptions import NeMoBaseException
from nemo.utils.get_rank import get_rank, is_global_rank_zero
from nemo.utils.memory_profile_analyzer import peak_memory_analysis

__all__ = ['ModelPT']

@@ -204,6 +206,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

# Setup nsys profiling if it has been enabled in the model config
self._setup_profiling()
# Track the real batch index ourselves; the `batch_idx` passed to `on_train_batch_start` and `on_train_batch_end` has a bug.
self._real_batch_idx = 0

# A flag for the profile generation
self._nsys_profile_started = False

@@ -1747,7 +1751,9 @@ def _setup_profiling(self):
end_step: 10 # Global batch to end profiling
rank: 0 # Global rank ID to profile
output_path: None # Path to store the profile output file
Minimum Pytorch version for memory_profile: 2.2.0
"""
min_pt_version = "2.2.0"
if self.cfg.get('nsys_profile', None) is not None:
if self.cfg.nsys_profile.get('enabled', False):
# Nsys profiling options

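For readers of this diff, the sketch below shows roughly how the memory_profile options read by `_setup_profiling` could be supplied. It is a hedged example built only from the `cfg.memory_profile.get(...)` calls in this PR, not a config file taken from NeMo; check key names against the code before relying on it.

# Minimal sketch (not part of the PR) of a model config enabling the CUDA memory
# profiler, using the keys that _setup_profiling() reads in this diff.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "memory_profile": {
            "enabled": True,  # turn CUDA memory profiling on
            "start_step": 5,  # global batch at which activation recording starts
            "end_step": 6,  # global batch at which the activation snapshot is dumped
            "rank": 0,  # global rank to profile
            "output_path": "/tmp/mem_profile",  # must be an existing directory
            "analysis_enabled": True,  # also run peak_memory_analysis on the snapshots
        }
    }
)
print(OmegaConf.to_yaml(cfg))
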
@@ -1776,36 +1782,87 @@ def _setup_profiling(self):

if self.cfg.get('memory_profile', None) is not None:
if self.cfg.memory_profile.get('enabled', False):
# CUDA memory profiling options
self._memory_profile_enabled = True
self._memory_profile_start_step = self.cfg.memory_profile.get('start_step', 0)
self._memory_profile_end_step = self.cfg.memory_profile.get('end_step', 0)
self._memory_profile_rank = self.cfg.memory_profile.get('rank', 0)
self._memory_profile_output_path = self.cfg.memory_profile.get('output_path', None)

if type(self._memory_profile_start_step) == int:
logging.info(f'Nsys profiling setup with start_step: {self._memory_profile_start_step}')
else:
# Guard memory profile availability by checking the PyTorch version; it must be at least min_pt_version.
if version.parse(torch.__version__) < version.parse(min_pt_version):
raise ValueError(
f'CUDA memory start_step must be of type int. Found: {type(self._memory_profile_start_step)}'
f"Minimum Pytorch version for memory_profile is {min_pt_version}. Found: {torch.__version__}"
)

if type(self._memory_profile_end_step) == int:
logging.info(f'CUDA memory profiling setup with end_step: {self._memory_profile_end_step}')
else:
raise ValueError(
f'CUDA memory end_step must be of type int. Found: {type(self._memory_profile_end_step)}'
# CUDA memory profiling options
self._memory_profile_enabled = True
self._memory_profile_start_step = self.cfg.memory_profile.get('start_step', 0)
self._memory_profile_end_step = self.cfg.memory_profile.get('end_step', 0)
self._memory_profile_rank = self.cfg.memory_profile.get('rank', 0)
self._memory_profile_output_path = self.cfg.memory_profile.get('output_path', None)
self._memory_profile_snapshot_file_activation = (
f'{self._memory_profile_output_path}/memory_profile_rank{self._memory_profile_rank}_act.pickle'
)
self._memory_profile_snapshot_file_weight = (
f'{self._memory_profile_output_path}/memory_profile_rank{self._memory_profile_rank}_weight.pickle'
)
self._memory_profile_snapshot_file_oom = (
f'{self._memory_profile_output_path}/memory_profile_rank{self._memory_profile_rank}_oom.pickle'
)
# CUDA memory profiling options: analysis
self._memory_profile_analysis_enabled = self.cfg.memory_profile.get('analysis_enabled', False)
self._memory_profile_analysis_path = os.path.join(self._memory_profile_output_path, f"analysis")

if self._memory_profile_end_step >= self._memory_profile_start_step:
pass
else:
raise ValueError(f'CUDA memory end_step must be greater than or equal to memory start_step')
logging.info(
f"====== Rank[{self._memory_profile_rank}], Initialization. Start CUDA memory profiling: Weight ======"
)
torch.cuda.memory._record_memory_history()
Review comment: Does this need a specific version of PyTorch? If so, please guard it so that it doesn't run on lower PyTorch versions.

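As background for the review question above, here is a standalone sketch of the snapshot workflow this hunk relies on. `torch.cuda.memory._record_memory_history()` and `_dump_snapshot()` are private PyTorch APIs; the guard mirrors the `packaging.version` check added in this PR, and the 2.2.0 minimum is taken from the diff rather than verified against every PyTorch release.

# Hedged sketch (not part of the PR): record CUDA allocator history and dump a
# snapshot, guarded by the same kind of version check this PR adds.
import torch
from packaging import version

MIN_PT_VERSION = "2.2.0"  # minimum assumed by this PR


def record_and_dump(snapshot_path: str) -> None:
    if version.parse(torch.__version__) < version.parse(MIN_PT_VERSION):
        raise RuntimeError(
            f"memory profiling needs PyTorch >= {MIN_PT_VERSION}, found {torch.__version__}"
        )
    torch.cuda.memory._record_memory_history()  # start recording allocator events (private API)
    # ... run the workload to be profiled here ...
    torch.cuda.memory._dump_snapshot(snapshot_path)  # write the snapshot pickle
    torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
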
if self._memory_profile_output_path is None or not os.path.isdir(self._memory_profile_output_path):
raise ValueError(
f'Memory profile output path ({self._memory_profile_output_path}) is not set or does not exist.'
torch._C._cuda_attach_out_of_memory_observer(self._oom_observer)

if type(self._memory_profile_start_step) == int:
logging.info(
f'CUDA memory profiling (Activation) setup with start_step: {self._memory_profile_start_step}'
)
else:
raise ValueError(
f'CUDA memory start_step must be of type int. Found: {type(self._memory_profile_start_step)}'
)

if type(self._memory_profile_end_step) == int:
logging.info(
f'CUDA memory profiling (Activation) setup with end_step: {self._memory_profile_end_step}'
)
else:
raise ValueError(
f'CUDA memory end_step must be of type int. Found: {type(self._memory_profile_end_step)}'
)

if self._memory_profile_end_step >= self._memory_profile_start_step:
pass
else:
raise ValueError(
f'CUDA memory (Activation) end_step must be greater than or equal to memory start_step'
)

if self._memory_profile_output_path is None or not os.path.isdir(self._memory_profile_output_path):
raise ValueError(
f'Memory profile output path ({self._memory_profile_output_path}) is not set or does not exist.'
)

def _oom_observer(self, device, alloc, device_alloc, device_free, *args, **kwargs):
if get_rank() == self._memory_profile_rank:
logging.info(
f"====== Rank[{self._memory_profile_rank}]. OOM Profile. End CUDA memory profiling: Out Of Memory. Snapshot saved in {self._memory_profile_snapshot_file_oom} ======"
)
torch.cuda.memory._dump_snapshot(f"{self._memory_profile_snapshot_file_oom}")
Review comment: Guard if necessary.

# If the snapshot exists, call the peak-memory analyzer and export the CSV file.
# Note: the analysis keeps running for a short while after the OOM error appears, so expect a small delay.
if self._memory_profile_analysis_enabled:
if os.path.exists(self._memory_profile_snapshot_file_oom):
logging.info(f"===== Memory Profile Analysis: OOM ======")
peak_memory_analysis(
self._memory_profile_snapshot_file_oom,
self._memory_profile_analysis_path,
'oom',
self._memory_profile_rank,
)
else:
raise Exception(f"Snapshot file not found: {self._memory_profile_snapshot_file_oom}")
Review comment: Maybe move this after line 1833.

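For context, the `_oom_observer` hook above uses a private allocator callback. A minimal standalone sketch of that mechanism is shown below; the callback signature (device, alloc, device_alloc, device_free) is the one used in this diff, and `torch._C._cuda_attach_out_of_memory_observer` should be treated as an internal, unstable API.

# Hedged sketch (not part of the PR): dump a memory snapshot when a CUDA
# allocation fails with out-of-memory.
import torch


def oom_observer(device, alloc, device_alloc, device_free):
    # Invoked by the caching allocator just before an OOM error is raised.
    print(
        f"OOM on device {device}: requested {alloc} bytes "
        f"(device allocated {device_alloc}, device free {device_free})"
    )
    torch.cuda.memory._dump_snapshot("oom_snapshot.pickle")


torch.cuda.memory._record_memory_history()  # record history so the snapshot contains traces
torch._C._cuda_attach_out_of_memory_observer(oom_observer)
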
def on_train_start(self):
"""PyTorch Lightning hook:

@@ -1831,24 +1888,74 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, unused: int = 0) -> O
https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-batch-start
We use it here to enable nsys profiling and dynamic freezing.
"""

# nsys profiling
if self.device.type == 'cuda':
if hasattr(self, '_nsys_profile_enabled'):
if self._nsys_profile_enabled and not self._nsys_profile_started:
if batch_idx >= self._nsys_profile_start_step and get_rank() in self._nsys_profile_ranks:
if (
self._real_batch_idx >= self._nsys_profile_start_step
and get_rank() in self._nsys_profile_ranks
):
logging.info("====== Start nsys profiling ======")
torch.cuda.cudart().cudaProfilerStart()
if self._nsys_profile_gen_shape:
torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()
self._nsys_profile_started = True

if hasattr(self, '_memory_profile_enabled'):
if self._memory_profile_enabled:
if (
(self._real_batch_idx == 0)
and (get_rank() == self._memory_profile_rank)
): # before batch-0 start
try:
logging.info(
f"====== Rank[{self._memory_profile_rank}], Batch[{batch_idx}]. End CUDA memory profiling: Weight. Snapshot saved in {self._memory_profile_snapshot_file_weight} ======"
)
torch.cuda.memory._dump_snapshot(f"{self._memory_profile_snapshot_file_weight}")
# torch.cuda.memory._record_memory_history(enabled=None)
except Exception as e:
logging.error(f"Failed to capture memory snapshot {e}")
return

# Call the analysis function
if self._memory_profile_analysis_enabled:
Review comment: do you need this if? I would assume

if not os.path.exists(self._memory_profile_analysis_path):
os.makedirs(self._memory_profile_analysis_path)
if os.path.exists(self._memory_profile_snapshot_file_weight):
logging.info(f"====== Memory Profile Analysis: Weight ======")
peak_memory_analysis(
self._memory_profile_snapshot_file_weight,
self._memory_profile_analysis_path,
'weight',
self._memory_profile_rank,
)
else:
raise Exception(
f"Snapshot file not found: {self._memory_profile_snapshot_file_weight}"
)

if self._memory_profile_enabled and not self._memory_profile_started:
if batch_idx >= self._memory_profile_start_step and get_rank() == self._memory_profile_rank:
logging.info("====== Start CUDA memory profiling ======")
torch.cuda.memory._record_memory_history(max_entries=100000)
if (
self._real_batch_idx == self._memory_profile_start_step
and get_rank() == self._memory_profile_rank
):
logging.info(
f"====== Rank[{self._memory_profile_rank}], Batch[{batch_idx}]. Start CUDA memory profiling: Activation ======"
)
# Record the currently allocated memory before this training batch.
self._memory_profile_allocated = torch.cuda.memory_allocated() / (
1024 * 1024 * 1024
) # in GB; at this point it should correspond to the weight memory.
# Restart recording memory history
torch.cuda.memory._record_memory_history(enabled=None)
torch.cuda.memory._record_memory_history(
max_entries=30000000
) # Use a very large max_entries to avoid unwanted truncation; a few global batches normally stay well below this limit, which exists only to cap the snapshot file size.
self._memory_profile_started = True
logging.info(
f"Before recording activation memory snapshot, allocated memory: {self._memory_profile_allocated} GB"
)

# dynamic freezing
if hasattr(self, '_freeze_cfg') and self._freeze_cfg is not None:

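Before feeding the weight/activation snapshots written by this hook into the PR's `peak_memory_analysis`, it can be useful to sanity-check them. The sketch below assumes the snapshot layout produced by recent PyTorch builds (a dict with `segments` and `device_traces` keys); those field names are an assumption about PyTorch's format, not something defined in this PR.

# Hedged sketch (not part of the PR): peek into a dumped CUDA memory snapshot.
import pickle


def summarize_snapshot(path: str) -> None:
    with open(path, "rb") as f:
        snap = pickle.load(f)
    segments = snap.get("segments", [])
    reserved = sum(seg.get("total_size", 0) for seg in segments)
    traces = snap.get("device_traces", [])
    events = sum(len(t) for t in traces)
    print(f"{path}: {len(segments)} segments, {reserved / 1024**3:.2f} GiB reserved, {events} trace events")


# file name follows the naming pattern used in this diff, for rank 0
summarize_snapshot("memory_profile_rank0_weight.pickle")
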
@@ -1877,24 +1984,44 @@ def on_train_batch_end(self, outputs, batch: Any, batch_idx: int, unused: int =
https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-batch-end
We use it here to enable nsys profiling.
"""

if self.device.type == 'cuda':
if hasattr(self, '_nsys_profile_enabled'):
if self._nsys_profile_enabled and not self._nsys_profile_complete:
if batch_idx >= self._nsys_profile_end_step and get_rank() in self._nsys_profile_ranks:
if self._real_batch_idx >= self._nsys_profile_end_step and get_rank() in self._nsys_profile_ranks:
logging.info("====== End nsys profiling ======")
torch.cuda.cudart().cudaProfilerStop()
self._nsys_profile_complete = True

if hasattr(self, '_memory_profile_enabled'):
if self._memory_profile_enabled and not self._memory_profile_complete:
if batch_idx >= self._memory_profile_end_step and get_rank() == self._memory_profile_rank:
logging.info("====== End CUDA memory profiling ======")
torch.cuda.memory._dump_snapshot(
f'{self._memory_profile_output_path}/memory_profile_rank{self._memory_profile_rank}.pickle'
if (
self._real_batch_idx == self._memory_profile_end_step
and get_rank() == self._memory_profile_rank
):
logging.info(
f"====== Rank[{self._memory_profile_rank}], Batch[{batch_idx}]. End CUDA memory profiling: Activation. Snapshot saved in {self._memory_profile_snapshot_file_activation} ======"
)
torch.cuda.memory._dump_snapshot(self._memory_profile_snapshot_file_activation)
torch.cuda.memory._record_memory_history(enabled=None)
self._memory_profile_complete = True
# Call the analysis function
if self._memory_profile_analysis_enabled and self._memory_profile_complete:
Review comment: same as previously,

if not os.path.exists(self._memory_profile_analysis_path):
os.makedirs(self._memory_profile_analysis_path)
if os.path.exists(self._memory_profile_snapshot_file_activation):
logging.info(f"====== Memory Profile Analysis: Activation ======")
peak_memory_analysis(
self._memory_profile_snapshot_file_activation,
self._memory_profile_analysis_path,
'act',
self._memory_profile_rank,
)
else:
raise Exception(
f"Snapshot file not found: {self._memory_profile_snapshot_file_activation}"
)
# increase the batch_idx
self._real_batch_idx += 1

def _cleanup_on_execution_end(self):
"""

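`peak_memory_analysis` itself lives in `nemo.utils.memory_profile_analyzer` and is not shown in this diff; per the comments above, it consumes these snapshots and exports CSV files under the `analysis` directory. Conceptually, a peak analysis replays the allocator trace and tracks the running high-water mark. The sketch below is a rough standalone approximation under an assumed `device_traces` event format (`action` and `size` fields); it is not the PR's implementation.

# Hedged sketch (not part of the PR): approximate the peak allocated memory by
# replaying alloc/free events from one device's trace in a snapshot.
import pickle


def approx_peak_gib(path: str, device: int = 0) -> float:
    with open(path, "rb") as f:
        snap = pickle.load(f)
    current = peak = 0
    for event in snap.get("device_traces", [[]])[device]:
        action = event.get("action", "")
        size = event.get("size", 0)
        if action == "alloc":
            current += size
        elif action == "free_completed":
            current -= size
        peak = max(peak, current)
    return peak / 1024**3


print(f"approx peak: {approx_peak_gib('memory_profile_rank0_act.pickle'):.2f} GiB")
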
Review comment: Can this be avoided by instead using self.trainer.global_step, rather than checking trainer is not None? (It's never None during the training step call anyway.)