From cd4649534391d0e23d0455bd434ea5ded49df465 Mon Sep 17 00:00:00 2001
From: Tim Moon
Date: Tue, 2 Apr 2024 15:41:08 -0700
Subject: [PATCH 1/2] Make sure embedding grads are reduced in FP32

Signed-off-by: Tim Moon
---
 .../language_modeling/megatron_base_model.py | 24 +++++++++++++++++--
 1 file changed, 22 insertions(+), 2 deletions(-)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
index 85867df672f2..6b1dd3719d61 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
@@ -766,8 +766,28 @@ def get_config_arg(key: str, default_value: Optional[Any] = None) -> Any:
 
             # Make sure embedding grad reductions are in FP32
             if optim_dtype == torch.float32:
-                for name, param in self.named_parameters():
-                    if 'word_embedding' in name or 'position_embedding' in name or 'output_layer' in name:
+                fp32_params = []
+                modules = self.get_model_module_list()
+                if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
+                    if self.mcore_gpt:
+                        fp32_params.append(modules[0].shared_embedding_or_output_weight())
+                        fp32_params.append(modules[0].language_model.embedding.position_embeddings.weight)
+                    else:
+                        fp32_params.append(modules[0].word_embeddings_weight())
+                        fp32_params.append(modules[0].position_embeddings_weight())
+                if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
+                    share_embeddings_and_output_weights = (
+                        modules[-1].share_embeddings_and_output_weights
+                        if self.mcore_gpt
+                        else modules[-1].share_token_embeddings
+                    )
+                    if share_embeddings_and_output_weights:
+                        if self.mcore_gpt:
+                            fp32_params.append(modules[-1].shared_embedding_or_output_weight())
+                        else:
+                            fp32_params.append(modules[-1].word_embeddings_weight())
+                for param in fp32_params:
+                    if param is not None:
                         param._with_fp32_optimizer = True
 
             # Match param allgather with model dtype

From 9e149bdd55bf691e193e18fa62b0033b154a9683 Mon Sep 17 00:00:00 2001
From: Tim Moon
Date: Wed, 3 Apr 2024 03:50:57 +0000
Subject: [PATCH 2/2] Access correct attr to get position embeddings

Signed-off-by: Tim Moon
---
 .../nlp/models/language_modeling/megatron_base_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
index 6b1dd3719d61..56eca8194263 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
@@ -771,7 +771,7 @@ def get_config_arg(key: str, default_value: Optional[Any] = None) -> Any:
                 if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
                     if self.mcore_gpt:
                         fp32_params.append(modules[0].shared_embedding_or_output_weight())
-                        fp32_params.append(modules[0].language_model.embedding.position_embeddings.weight)
+                        fp32_params.append(modules[0].embedding.position_embeddings.weight)
                     else:
                         fp32_params.append(modules[0].word_embeddings_weight())
                         fp32_params.append(modules[0].position_embeddings_weight())
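
For context, a minimal standalone sketch of the pattern these two commits implement: collect the embedding weights owned by the first pipeline stage and, when embeddings and output weights are tied, the duplicate weight on the last stage, skip any that are absent on this rank, and tag the rest with the `_with_fp32_optimizer` flag so their grad reductions happen in FP32. It is not part of the patch; `FakeStageModule`, `collect_fp32_grad_params`, and the stage flags are hypothetical stand-ins for Megatron's `parallel_state` queries and NeMo's module accessors, used only so the logic runs in isolation.

    # Illustrative sketch only, not part of the patch. Hypothetical stand-ins
    # (FakeStageModule, collect_fp32_grad_params, the stage flags) replace Megatron's
    # parallel_state queries and NeMo's module accessors.
    from dataclasses import dataclass
    from typing import List, Optional

    import torch


    @dataclass
    class FakeStageModule:
        """Hypothetical stand-in for the model chunk owned by one pipeline stage."""

        word_embedding: Optional[torch.nn.Parameter] = None
        position_embedding: Optional[torch.nn.Parameter] = None
        tied_output_weight: Optional[torch.nn.Parameter] = None
        share_embeddings_and_output_weights: bool = False


    def collect_fp32_grad_params(
        modules: List[FakeStageModule], is_first_stage: bool, is_last_stage: bool
    ) -> List[torch.nn.Parameter]:
        """Collect embedding/output weights and tag them for FP32 grad reduction."""
        candidates: List[Optional[torch.nn.Parameter]] = []
        if is_first_stage:
            # The first pipeline stage owns the word and position embeddings.
            candidates.append(modules[0].word_embedding)
            candidates.append(modules[0].position_embedding)
        if is_last_stage and modules[-1].share_embeddings_and_output_weights:
            # The last stage holds a duplicate of the tied embedding/output weight.
            candidates.append(modules[-1].tied_output_weight)
        fp32_params = [p for p in candidates if p is not None]
        for param in fp32_params:
            # Flag read later by the distributed optimizer setup, per the commit message,
            # so that these grads are reduced in FP32.
            param._with_fp32_optimizer = True
        return fp32_params


    if __name__ == "__main__":
        stage = FakeStageModule(
            word_embedding=torch.nn.Parameter(torch.randn(16, 8)),
            position_embedding=torch.nn.Parameter(torch.randn(16, 8)),
        )
        tagged = collect_fp32_grad_params([stage], is_first_stage=True, is_last_stage=True)
        assert all(getattr(p, "_with_fp32_optimizer", False) for p in tagged)
        print(f"tagged {len(tagged)} params for FP32 grad reduction")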