Skip to content

Commit

Permalink
Update embedding init prototype to match mcore (Megatron-Core)
Browse files Browse the repository at this point in the history
Signed-off-by: Layali R <lrashid@nvidia.com>
  • Loading branch information
layalir committed Apr 3, 2024
1 parent 3497afa commit 9c4888d
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -558,15 +558,15 @@ def setup(self, stage=None):
for i, module in enumerate(self.model):
parallel_state.set_virtual_pipeline_model_parallel_rank(i)
sync_embeddings = (
module.initialize_last_stage_with_word_embeddings
module.setup_embeddings_and_output_layer
if self.mcore_bert
else module.sync_initial_word_embeddings
)
sync_embeddings()
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
else:
sync_embeddings = (
self.model.initialize_last_stage_with_word_embeddings
self.model.setup_embeddings_and_output_layer
if self.mcore_bert
else self.model.sync_initial_word_embeddings
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,7 @@ def setup(self, stage=None):
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
parallel_state.set_virtual_pipeline_model_parallel_rank(index)
sync_embeddings = (
module.initialize_last_stage_with_word_embeddings
module.setup_embeddings_and_output_layer
if self.mcore_bert
else module.sync_initial_word_embeddings
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1678,7 +1678,7 @@ def initialize_last_rank_embeddings(self):
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
parallel_state.set_virtual_pipeline_model_parallel_rank(index)
sync_embeddings = (
module.initialize_last_stage_with_word_embeddings
module.setup_embeddings_and_output_layer
if self.mcore_gpt
else module.sync_initial_word_embeddings
)
Expand Down

0 comments on commit 9c4888d

Please sign in to comment.