Skip to content

Commit

Permalink
Distributed optimizer reduces GPT embedding grads in FP32 (#8792)
Browse files Browse the repository at this point in the history
* Make sure embedding grads are reduced in FP32

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Access correct attr to get position embeddings

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
  • Loading branch information
timmoon10 authored Apr 5, 2024
1 parent 6ffc504 commit cf3b3a5
Showing 1 changed file with 22 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -770,8 +770,28 @@ def get_config_arg(key: str, default_value: Optional[Any] = None) -> Any:

# Make sure embedding grad reductions are in FP32
if optim_dtype == torch.float32:
for name, param in self.named_parameters():
if 'word_embedding' in name or 'position_embedding' in name or 'output_layer' in name:
fp32_params = []
modules = self.get_model_module_list()
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
if self.mcore_gpt:
fp32_params.append(modules[0].shared_embedding_or_output_weight())
fp32_params.append(modules[0].embedding.position_embeddings.weight)
else:
fp32_params.append(modules[0].word_embeddings_weight())
fp32_params.append(modules[0].position_embeddings_weight())
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
share_embeddings_and_output_weights = (
modules[-1].share_embeddings_and_output_weights
if self.mcore_gpt
else modules[-1].share_token_embeddings
)
if share_embeddings_and_output_weights:
if self.mcore_gpt:
fp32_params.append(modules[-1].shared_embedding_or_output_weight())
else:
fp32_params.append(modules[-1].word_embeddings_weight())
for param in fp32_params:
if param is not None:
param._with_fp32_optimizer = True

# Match param allgather with model dtype
Expand Down

0 comments on commit cf3b3a5

Please sign in to comment.