
Disable mcore saving
Signed-off-by: Igor Gitman <igitman@nvidia.com>
Kipok committed Oct 17, 2023
1 parent: b6a648c · commit: 7d96a34
Showing 1 changed file with 3 additions and 3 deletions.
@@ -1315,7 +1315,7 @@ def on_save_checkpoint(self, checkpoint) -> None:
         """
 
         # mcore uses distributed checkpointing
-        if self.mcore_gpt:
+        if False:
             checkpoint['sharded_state_dict'] = self.sharded_state_dict()
 
         # legacy checkpointing for interleaved
@@ -1332,7 +1332,7 @@ def on_load_checkpoint(self, checkpoint) -> None:
         """
 
         # mcore uses distributed checkpointing
-        if self.mcore_gpt:
+        if False:
            if 'state_dict' in checkpoint and checkpoint['state_dict']:
                for index, module in enumerate(self.get_gpt_module_list()):
                    if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
@@ -1366,7 +1366,7 @@ def sharded_state_dict(self, prefix: str = '') -> Dict[str, Any]:
         The sharded tensor mapping is defined in the GPTModel class from mcore.
         """
 
-        if self.mcore_gpt:
+        if False:
             module_prefix = f'{prefix}model.'
             sharded_state_dict = {}
             for index, module in enumerate(self.get_gpt_module_list()):
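For context, here is a minimal, self-contained sketch of what the change does. This is not the actual NeMo MegatronGPTModel code and all names here are illustrative; it only demonstrates the effect of hard-coding the condition to False: on_save_checkpoint never attaches a 'sharded_state_dict' entry, so the trainer falls back to saving the ordinary full state_dict instead of mcore's distributed (sharded) checkpoint format.

from typing import Any, Dict


class CheckpointingSketch:
    """Illustrative stand-in for the model class touched by this commit."""

    def __init__(self, mcore_gpt: bool):
        self.mcore_gpt = mcore_gpt

    def sharded_state_dict(self) -> Dict[str, Any]:
        # Stand-in for mcore's sharded-tensor mapping.
        return {"layer.weight": "ShardedTensor(...)"}

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Before the commit: `if self.mcore_gpt:` -- sharded saving enabled.
        # After the commit:  `if False:` -- sharded saving unconditionally off.
        if False:  # was: if self.mcore_gpt:
            checkpoint["sharded_state_dict"] = self.sharded_state_dict()


checkpoint: Dict[str, Any] = {"state_dict": {"layer.weight": "Tensor(...)"}}
CheckpointingSketch(mcore_gpt=True).on_save_checkpoint(checkpoint)
assert "sharded_state_dict" not in checkpoint  # sharded path is disabled

Hard-coding `if False:` dead-codes the branch while leaving it in place, which keeps the diff small and makes the change trivial to revert by restoring the original `if self.mcore_gpt:` condition.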
