From 806d8bba1728927fd456ffdee1fba3946e2b04d3 Mon Sep 17 00:00:00 2001
From: Sangkug Lym
Date: Thu, 2 May 2024 10:14:25 -0700
Subject: [PATCH] CUDA memory profile

Signed-off-by: Sangkug Lym
---
 .../clip/megatron_clip_models.py              | 10 ++-
 .../language_modeling/megatron_bert_model.py  | 10 ++-
 .../language_modeling/megatron_gpt_model.py   | 10 ++-
 .../megatron_gpt_sft_model.py                 |  3 +
 .../megatron_vit_classification_models.py     | 10 ++-
 nemo/core/classes/modelPT.py                  | 73 +++++++++++++++++--
 6 files changed, 98 insertions(+), 18 deletions(-)

diff --git a/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py b/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py
index fe35ae148026a..7be7407b98ae0 100644
--- a/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py
+++ b/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py
@@ -358,12 +358,16 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         self.transformer_engine = cfg.get('transformer_engine', False)

         # Convert the global-batch-based profile index to micro-batch index
-        if hasattr(self, '_nsys_profile_enabled'):
+        if hasattr(self, '_nsys_profile_enabled') or hasattr(self, '_memory_profile_enabled'):
             mp_size = cfg.get('tensor_model_parallel_size', 1) * cfg.get('pipeline_model_parallel_size', 1)
             data_parallel_world_size = trainer.world_size // mp_size
             grad_accum_steps = cfg.get('global_batch_size') // (cfg.get('micro_batch_size') * data_parallel_world_size)
-            self._nsys_profile_start_step *= grad_accum_steps
-            self._nsys_profile_end_step *= grad_accum_steps
+            if hasattr(self, '_nsys_profile_enabled'):
+                self._nsys_profile_start_step *= grad_accum_steps
+                self._nsys_profile_end_step *= grad_accum_steps
+            if hasattr(self, '_memory_profile_enabled'):
+                self._memory_profile_start_step *= grad_accum_steps
+                self._memory_profile_end_step *= grad_accum_steps

         self.get_attention_mask_from_fusion = self.cfg.get('get_attention_mask_from_fusion', True)
         self.initialize_ub = self.cfg.get('ub_tp_comm_overlap', False)
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py
index 0f1fa76f9b019..984fca5f12591 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py
@@ -136,12 +136,16 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
             # Model wrapper to convert both model and inputs to half precision
             self._wrap_model_for_O2()

-        if hasattr(self, '_nsys_profile_enabled'):
+        if hasattr(self, '_nsys_profile_enabled') or hasattr(self, '_memory_profile_enabled'):
             mp_size = cfg.get('tensor_model_parallel_size', 1) * cfg.get('pipeline_model_parallel_size', 1)
             data_parallel_world_size = trainer.world_size // mp_size
             grad_accum_steps = cfg.get('global_batch_size') // (cfg.get('micro_batch_size') * data_parallel_world_size)
-            self._nsys_profile_start_step *= grad_accum_steps
-            self._nsys_profile_end_step *= grad_accum_steps
+            if hasattr(self, '_nsys_profile_enabled'):
+                self._nsys_profile_start_step *= grad_accum_steps
+                self._nsys_profile_end_step *= grad_accum_steps
+            if hasattr(self, '_memory_profile_enabled'):
+                self._memory_profile_start_step *= grad_accum_steps
+                self._memory_profile_end_step *= grad_accum_steps

     def model_provider_func(self, pre_process, post_process):
         cfg = self.cfg
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index d7f489abf1584..e2124a38b1c6f 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -354,13 +354,17 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         self._inference_config = None

         # Convert the global-batch-based profile index to micro-batch index
-        if hasattr(self, '_nsys_profile_enabled'):
+        if hasattr(self, '_nsys_profile_enabled') or hasattr(self, '_memory_profile_enabled'):
             mp_size = cfg.get('tensor_model_parallel_size', 1) * cfg.get('pipeline_model_parallel_size', 1)
             cp_size = cfg.get('context_parallel_size', 1)
             data_parallel_world_size = trainer.world_size // (mp_size * cp_size)
             grad_accum_steps = cfg.get('global_batch_size') // (cfg.get('micro_batch_size') * data_parallel_world_size)
-            self._nsys_profile_start_step *= grad_accum_steps
-            self._nsys_profile_end_step *= grad_accum_steps
+            if hasattr(self, '_nsys_profile_enabled'):
+                self._nsys_profile_start_step *= grad_accum_steps
+                self._nsys_profile_end_step *= grad_accum_steps
+            if hasattr(self, '_memory_profile_enabled'):
+                self._memory_profile_start_step *= grad_accum_steps
+                self._memory_profile_end_step *= grad_accum_steps

         self.get_attention_mask_from_fusion = self.cfg.get('get_attention_mask_from_fusion', True)
         self.initialize_ub = self.cfg.get('ub_tp_comm_overlap', False)
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
index 892a87189880d..1d76013c37172 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
@@ -94,6 +94,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         if hasattr(self, '_nsys_profile_enabled'):
             self._nsys_profile_start_step = self.cfg.nsys_profile.get('start_step', 0)
             self._nsys_profile_end_step = self.cfg.nsys_profile.get('end_step', 0)
+        if hasattr(self, '_memory_profile_enabled'):
+            self._memory_profile_start_step = self.cfg.memory_profile.get('start_step', 0)
+            self._memory_profile_end_step = self.cfg.memory_profile.get('end_step', 0)

         self.virtual_tokens = 0
         self.init_global_step = 0
diff --git a/nemo/collections/vision/models/megatron_vit_classification_models.py b/nemo/collections/vision/models/megatron_vit_classification_models.py
index ea6d3578c540b..46788d2c882c2 100644
--- a/nemo/collections/vision/models/megatron_vit_classification_models.py
+++ b/nemo/collections/vision/models/megatron_vit_classification_models.py
@@ -181,12 +181,16 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         self.transformer_engine = cfg.get('transformer_engine', False)

         # Convert the global-batch-based profile index to micro-batch index
-        if hasattr(self, '_nsys_profile_enabled'):
+        if hasattr(self, '_nsys_profile_enabled') or hasattr(self, '_memory_profile_enabled'):
             mp_size = cfg.get('tensor_model_parallel_size', 1) * cfg.get('pipeline_model_parallel_size', 1)
             data_parallel_world_size = trainer.world_size // mp_size
             grad_accum_steps = cfg.get('global_batch_size') // (cfg.get('micro_batch_size') * data_parallel_world_size)
-            self._nsys_profile_start_step *= grad_accum_steps
-            self._nsys_profile_end_step *= grad_accum_steps
+            if hasattr(self, '_nsys_profile_enabled'):
+                self._nsys_profile_start_step *= grad_accum_steps
+                self._nsys_profile_end_step *= grad_accum_steps
+            if hasattr(self, '_memory_profile_enabled'):
+                self._memory_profile_start_step *= grad_accum_steps
+                self._memory_profile_end_step *= grad_accum_steps

         self.get_attention_mask_from_fusion = self.cfg.get('get_attention_mask_from_fusion', True)
         self.initialize_ub = self.cfg.get('ub_tp_comm_overlap', False)
diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py
index d5cd18179e8b7..53f3d745feb47 100644
--- a/nemo/core/classes/modelPT.py
+++ b/nemo/core/classes/modelPT.py
@@ -192,10 +192,13 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
             self.training_step = model_utils.wrap_training_step(self.training_step)

         # Setup nsys profiling if it has been enabled in the model config
-        self._setup_nsys_profiling()
+        self._setup_profiling()

         # A flag for the profile generation
-        self._profile_complete = False
+        self._nsys_profile_started = False
+        self._nsys_profile_complete = False
+        self._memory_profile_started = False
+        self._memory_profile_complete = False

     def __init_subclass__(cls) -> None:
         cls._save_restore_connector = SaveRestoreConnector()
@@ -1664,7 +1667,7 @@ def update_save_restore_connector(cls, save_restore_connector):
         else:
             setattr(cls, '_save_restore_connector', save_restore_connector)

-    def _setup_nsys_profiling(self):
+    def _setup_profiling(self):
         """ Enables nsys profiling
         To use, add the following optoins to the model config:
         ## Nsys profiling options
@@ -1676,6 +1679,15 @@
         And then wrap the model training script with:
         nsys profile -s none -o <profile filepath> -t cuda,nvtx --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop python ./examples/...
         See more options at: https://docs.nvidia.com/nsight-systems/UserGuide/index.html#cli-profiling
+
+        Enables CUDA memory profiling
+        To use, add the following options to the model config:
+        ## CUDA memory profiling options
+        memory_profile: False
+            start_step: 10  # Global batch to start profiling
+            end_step: 10 # Global batch to end profiling
+            rank: 0 # Global rank ID to profile
+            output_path: None # Path to store the profile output file
         """
         if self.cfg.get('nsys_profile', None) is not None:
             if self.cfg.nsys_profile.get('enabled', False):
@@ -1703,6 +1715,37 @@
                 else:
                     raise ValueError(f'Nsys end_step must be greater than or equal to nsys start_step')

+        if self.cfg.get('memory_profile', None) is not None:
+            if self.cfg.memory_profile.get('enabled', False):
+                # CUDA memory profiling options
+                self._memory_profile_enabled = True
+                self._memory_profile_start_step = self.cfg.memory_profile.get('start_step', 0)
+                self._memory_profile_end_step = self.cfg.memory_profile.get('end_step', 0)
+                self._memory_profile_rank = self.cfg.memory_profile.get('rank', 0)
+                self._memory_profile_output_path = self.cfg.memory_profile.get('output_path', None)
+
+                if type(self._memory_profile_start_step) == int:
+                    logging.info(f'CUDA memory profiling setup with start_step: {self._memory_profile_start_step}')
+                else:
+                    raise ValueError(
+                        f'CUDA memory start_step must be of type int. Found: {type(self._memory_profile_start_step)}'
+                    )
+
+                if type(self._memory_profile_end_step) == int:
+                    logging.info(f'CUDA memory profiling setup with end_step: {self._memory_profile_end_step}')
+                else:
+                    raise ValueError(f'CUDA memory end_step must be of type int. Found: {type(self._memory_profile_end_step)}')
+
+                if self._memory_profile_end_step >= self._memory_profile_start_step:
+                    pass
+                else:
+                    raise ValueError(f'CUDA memory end_step must be greater than or equal to memory start_step')
+
+                if self._memory_profile_output_path == None or not os.path.isdir(self._memory_profile_output_path):
+                    raise ValueError(
+                        f'Memory profile output path ({self._memory_profile_output_path}) is not set or does not exist.'
+                    )
+
     def on_train_start(self):
         """ PyTorch Lightning hook:
         https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-start
@@ -1731,12 +1774,20 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, unused: int = 0) -> O
         # nsys profiling
         if self.device.type == 'cuda':
             if hasattr(self, '_nsys_profile_enabled'):
-                if self._nsys_profile_enabled and not self._profile_complete:
+                if self._nsys_profile_enabled and not self._nsys_profile_started:
                     if batch_idx >= self._nsys_profile_start_step and get_rank() in self._nsys_profile_ranks:
                         logging.info("====== Start nsys profiling ======")
                         torch.cuda.cudart().cudaProfilerStart()
                         if self._nsys_profile_gen_shape:
                             torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()
+                        self._nsys_profile_started = True
+
+            if hasattr(self, '_memory_profile_enabled'):
+                if self._memory_profile_enabled and not self._memory_profile_started:
+                    if batch_idx >= self._memory_profile_start_step and get_rank() == self._memory_profile_rank:
+                        logging.info("====== Start CUDA memory profiling ======")
+                        torch.cuda.memory._record_memory_history(max_entries=100000)
+                        self._memory_profile_started = True

         # dynamic freezing
         if hasattr(self, '_freeze_cfg') and self._freeze_cfg is not None:
@@ -1768,11 +1819,21 @@ def on_train_batch_end(self, outputs, batch: Any, batch_idx: int, unused: int = 0

         if self.device.type == 'cuda':
             if hasattr(self, '_nsys_profile_enabled'):
-                if self._nsys_profile_enabled and not self._profile_complete:
+                if self._nsys_profile_enabled and not self._nsys_profile_complete:
                     if batch_idx >= self._nsys_profile_end_step and get_rank() in self._nsys_profile_ranks:
                         logging.info("====== End nsys profiling ======")
                         torch.cuda.cudart().cudaProfilerStop()
-                        self._profile_complete = True
+                        self._nsys_profile_complete = True
+
+            if hasattr(self, '_memory_profile_enabled'):
+                if self._memory_profile_enabled and not self._memory_profile_complete:
+                    if batch_idx >= self._memory_profile_end_step and get_rank() == self._memory_profile_rank:
+                        logging.info("====== End CUDA memory profiling ======")
+                        torch.cuda.memory._dump_snapshot(
+                            f'{self._memory_profile_output_path}/memory_profile_rank{self._memory_profile_rank}.pickle'
+                        )
+                        torch.cuda.memory._record_memory_history(enabled=None)
+                        self._memory_profile_complete = True

     def _cleanup_on_execution_end(self):
         """
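
Usage sketch (illustrative, not part of the diff): the memory_profile options mirror the existing nsys_profile block, and the model constructors rescale the global-batch start/end steps by the number of gradient-accumulation micro-batches. The snippet below shows one hypothetical config plus the same arithmetic; the batch sizes, parallel sizes, world size, and output directory are made-up values.

# Hypothetical example of the memory_profile options this patch reads, plus the
# global-batch -> micro-batch conversion done in the model constructors.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "global_batch_size": 256,
        "micro_batch_size": 2,
        "tensor_model_parallel_size": 2,
        "pipeline_model_parallel_size": 1,
        "memory_profile": {
            "enabled": True,
            "start_step": 10,  # global batch to start profiling
            "end_step": 10,  # global batch to end profiling
            "rank": 0,  # global rank ID to profile
            "output_path": "/results/mem_profile",  # must be an existing directory
        },
    }
)

world_size = 16  # assumed trainer.world_size
mp_size = cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size
data_parallel_world_size = world_size // mp_size  # 16 // 2 = 8
grad_accum_steps = cfg.global_batch_size // (cfg.micro_batch_size * data_parallel_world_size)  # 256 // 16 = 16
print(cfg.memory_profile.start_step * grad_accum_steps)  # global batch 10 -> micro-batch step 160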
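
The new hooks wrap PyTorch's CUDA caching-allocator snapshot API (torch.cuda.memory._record_memory_history / _dump_snapshot, which are private interfaces). A minimal standalone sketch of the same record -> dump -> stop sequence follows; the model, step count, and output path are placeholders.

# Minimal sketch of the allocator-snapshot calls used by the new hooks.
# Requires a CUDA device; model and paths are placeholders.
import torch

torch.cuda.memory._record_memory_history(max_entries=100000)  # start recording (as in on_train_batch_start)

model = torch.nn.Linear(4096, 4096, device="cuda")
for _ in range(3):  # stand-in for the profiled training steps
    out = model(torch.randn(8, 4096, device="cuda"))
    out.sum().backward()

# Dump the snapshot and stop recording (as in on_train_batch_end).
# The .pickle file can be inspected with https://pytorch.org/memory_viz
torch.cuda.memory._dump_snapshot("/tmp/memory_profile_rank0.pickle")
torch.cuda.memory._record_memory_history(enabled=None)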