From 103884aeee3b92edbd9c447e78d21da6cacac8b4 Mon Sep 17 00:00:00 2001 From: Eugene Cheah Date: Thu, 13 Jul 2023 04:05:20 +0800 Subject: [PATCH] Update zero_to_fp32.py (#3936) Since deepspeed 1 checkpoint structure is identical to deepspeed 2 (AFAIK), we should just change the version check and add support accordingly --- deepspeed/utils/zero_to_fp32.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py index c5246ff52274..6b595ce40add 100755 --- a/deepspeed/utils/zero_to_fp32.py +++ b/deepspeed/utils/zero_to_fp32.py @@ -5,7 +5,7 @@ # DeepSpeed Team -# This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets +# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in # the future. Once extracted, the weights don't require DeepSpeed and can be used in any # application. @@ -63,7 +63,7 @@ def get_model_state_file(checkpoint_dir, zero_stage): raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") # there should be only one file - if zero_stage == 2: + if zero_stage <= 2: file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") elif zero_stage == 3: file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") @@ -164,14 +164,14 @@ def parse_optim_states(files, ds_checkpoint_dir): ) # the groups are named differently in each stage - if zero_stage == 2: + if zero_stage <= 2: fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS elif zero_stage == 3: fp32_groups_key = FP32_FLAT_GROUPS else: raise ValueError(f"unknown zero stage {zero_stage}") - if zero_stage == 2: + if zero_stage <= 2: fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))] elif zero_stage == 3: # if there is more than one param group, there will be multiple flattened tensors - one @@ -206,7 +206,7 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): zero_model_states = parse_model_states(model_files) print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') - if zero_stage == 2: + if zero_stage <= 2: return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states) elif zero_stage == 3: return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states)