Skip to content

Commit

Permalink
Update zero_to_fp32.py (#3936)
Browse files Browse the repository at this point in the history
Since deepspeed 1 checkpoint structure is identical to deepspeed 2 (AFAIK), we should just change the version check and add support accordingly
  • Loading branch information
PicoCreator authored Jul 12, 2023
1 parent aef6c65 commit 103884a
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions deepspeed/utils/zero_to_fp32.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

# DeepSpeed Team

# This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets
# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
# application.
Expand Down Expand Up @@ -63,7 +63,7 @@ def get_model_state_file(checkpoint_dir, zero_stage):
raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")

# there should be only one file
if zero_stage == 2:
if zero_stage <= 2:
file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
elif zero_stage == 3:
file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
Expand Down Expand Up @@ -164,14 +164,14 @@ def parse_optim_states(files, ds_checkpoint_dir):
)

# the groups are named differently in each stage
if zero_stage == 2:
if zero_stage <= 2:
fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
elif zero_stage == 3:
fp32_groups_key = FP32_FLAT_GROUPS
else:
raise ValueError(f"unknown zero stage {zero_stage}")

if zero_stage == 2:
if zero_stage <= 2:
fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
elif zero_stage == 3:
# if there is more than one param group, there will be multiple flattened tensors - one
Expand Down Expand Up @@ -206,7 +206,7 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
zero_model_states = parse_model_states(model_files)
print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')

if zero_stage == 2:
if zero_stage <= 2:
return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states)
elif zero_stage == 3:
return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states)
Expand Down

0 comments on commit 103884a

Please sign in to comment.