switch to mcore's optimizer #448

Merged 3 commits on Dec 5, 2024
Changes from all commits
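Nearly every file below makes the same one-line change: the 'optim.name' key switches from Apex's 'distributed_fused_adam' optimizer to Megatron Core's distributed optimizer, 'mcore_distributed_optim'. (The two gpt3 fine-tuning configs keep 'name: fused_adam' and only update the inline comment.) As an illustrative sketch only, with surrounding keys and values varying per config, an updated optim block looks like:

    optim:
      name: mcore_distributed_optim  # previously: distributed_fused_adam
      lr: 1e-4
      weight_decay: 0.1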
auto_configurator/base_configs/baichuan2_13b.yaml (1 addition & 1 deletion)
@@ -131,7 +131,7 @@ model:
 ub_tp_comm_overlap: False
 use_flash_attention: true
 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 grad_sync_dtype: bf16
 lr: 1e-4
 weight_decay: 0.1
auto_configurator/base_configs/baichuan2_7b.yaml (1 addition & 1 deletion)
@@ -131,7 +131,7 @@ model:
 ub_tp_comm_overlap: False
 use_flash_attention: true
 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 grad_sync_dtype: bf16
 lr: 1e-4
 weight_decay: 0.1
auto_configurator/base_configs/bert.yaml (1 addition & 1 deletion)
@@ -124,7 +124,7 @@ model:
 short_seq_prob: 0.1 # Probability of producing a short sequence.

 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 overlap_grad_sync: False
 bucket_cap_mb: ${training.model.grad_allreduce_chunk_size_mb}
 lr: 2e-4
auto_configurator/base_configs/chatglm_6b.yaml (1 addition & 1 deletion)
@@ -139,7 +139,7 @@ model:
 ub_tp_comm_overlap: False
 use_flash_attention: true
 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 lr: 1e-4
 weight_decay: 0.1
 betas:
auto_configurator/base_configs/gpt3.yaml (1 addition & 1 deletion)
@@ -142,7 +142,7 @@ model:
 gen_shape: False # Generate model and kernel details including input shapes

 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 overlap_grad_sync: False
 bucket_cap_mb: ${training.model.grad_allreduce_chunk_size_mb}
 lr: 6e-4
auto_configurator/base_configs/llama2_13b.yaml (1 addition & 1 deletion)
@@ -127,7 +127,7 @@ model:
 ub_tp_comm_overlap: false
 use_flash_attention: true
 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 lr: 0.0001
 weight_decay: 0.1
 betas:
auto_configurator/base_configs/llama2_70b.yaml (1 addition & 1 deletion)
@@ -130,7 +130,7 @@ model:
 batch_p2p_comm: true
 gc_interval: 100
 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 lr: 0.00015
 weight_decay: 0.1
 betas:
auto_configurator/base_configs/llama2_7b.yaml (1 addition & 1 deletion)
@@ -130,7 +130,7 @@ model:
 ub_tp_comm_overlap: False
 use_flash_attention: true
 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 lr: 1e-4
 weight_decay: 0.1
 betas:
auto_configurator/base_configs/llama3_70b.yaml (1 addition & 1 deletion)
@@ -150,7 +150,7 @@ model:
 ranks: [0] # Global rank IDs to profile
 gen_shape: False # Generate model and kernel details including input shapes
 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 lr: 0.00015
 weight_decay: 0.1
 betas:
auto_configurator/base_configs/llama3_8b.yaml (1 addition & 1 deletion)
@@ -148,7 +148,7 @@ model:
 ranks: [0] # Global rank IDs to profile
 gen_shape: False # Generate model and kernel details including input shapes
 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 lr: 1e-4
 weight_decay: 0.1
 betas:
auto_configurator/base_configs/mixtral_3b.yaml (1 addition & 1 deletion)
@@ -143,7 +143,7 @@ model:
 - 0
 gen_shape: false
 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 lr: 0.0001
 weight_decay: 0.1
 betas:
auto_configurator/base_configs/mixtral_7b.yaml (1 addition & 1 deletion)
@@ -145,7 +145,7 @@ model:
 - 0
 gen_shape: false
 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 lr: 0.0001
 weight_decay: 0.1
 betas:
auto_configurator/base_configs/mt5.yaml (1 addition & 1 deletion)
@@ -176,7 +176,7 @@ model:
 gen_shape: False # Generate model and kernel details including input shapes

 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 overlap_grad_sync: False
 bucket_cap_mb: ${training.model.grad_allreduce_chunk_size_mb}
 lr: 0.0001
auto_configurator/base_configs/qwen2_14b.yaml (1 addition & 1 deletion)
@@ -142,7 +142,7 @@ model:
 ub_tp_comm_overlap: False
 use_flash_attention: true
 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 lr: 1e-4
 weight_decay: 0.1
 betas:
auto_configurator/base_configs/qwen2_4b.yaml (1 addition & 1 deletion)
@@ -142,7 +142,7 @@ model:
 ub_tp_comm_overlap: False
 use_flash_attention: true
 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 lr: 1e-4
 weight_decay: 0.1
 betas:
auto_configurator/base_configs/qwen2_72b.yaml (1 addition & 1 deletion)
@@ -142,7 +142,7 @@ model:
 ub_tp_comm_overlap: False
 use_flash_attention: true
 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 lr: 1e-4
 weight_decay: 0.1
 betas:
auto_configurator/base_configs/qwen2_7b.yaml (1 addition & 1 deletion)
@@ -142,7 +142,7 @@ model:
 ub_tp_comm_overlap: False
 use_flash_attention: true
 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 lr: 1e-4
 weight_decay: 0.1
 betas:
auto_configurator/base_configs/t5.yaml (1 addition & 1 deletion)
@@ -174,7 +174,7 @@ model:
 gen_shape: False # Generate model and kernel details including input shapes

 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 overlap_grad_sync: False
 bucket_cap_mb: ${training.model.grad_allreduce_chunk_size_mb}
 lr: 0.0001
auto_configurator/tests/base_configs_tests/test_base_configs.py (6 additions & 6 deletions)
@@ -149,7 +149,7 @@ def test_gpt3_base_config(self):
 gen_shape: False # Generate model and kernel details including input shapes

 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 overlap_grad_sync: False
 bucket_cap_mb: ${training.model.grad_allreduce_chunk_size_mb}
 lr: 6e-4
@@ -336,7 +336,7 @@ def test_llama_base_config(self):
 ranks: [0] # Global rank IDs to profile
 gen_shape: False # Generate model and kernel details including input shapes
 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 lr: 1e-4
 weight_decay: 0.1
 betas:
@@ -594,7 +594,7 @@ def test_mixtral_base_config(self):
 - 0
 gen_shape: false
 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 lr: 0.0001
 weight_decay: 0.1
 betas:
@@ -867,7 +867,7 @@ def test_t5_base_config(self):
 gen_shape: False # Generate model and kernel details including input shapes

 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 overlap_grad_sync: False
 bucket_cap_mb: ${training.model.grad_allreduce_chunk_size_mb}
 lr: 0.0001
@@ -1092,7 +1092,7 @@ def test_mt5_base_config(self):
 gen_shape: False # Generate model and kernel details including input shapes

 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 overlap_grad_sync: False
 bucket_cap_mb: ${training.model.grad_allreduce_chunk_size_mb}
 lr: 0.0001
@@ -1263,7 +1263,7 @@ def test_bert_base_config(self):
 short_seq_prob: 0.1 # Probability of producing a short sequence.

 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 overlap_grad_sync: False
 bucket_cap_mb: ${training.model.grad_allreduce_chunk_size_mb}
 lr: 2e-4
launcher_scripts/conf/fine_tuning/baichuan2/squad.yaml (1 addition & 1 deletion)
@@ -171,7 +171,7 @@ model:
 num_classes: null

 optim:
-name: distributed_fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
+name: mcore_distributed_optim # Supports distributed optimizer for memory savings. To enable, set to 'mcore_distributed_optim'. Needs Apex to be built with specific args to work.
 lr: 1e-6
 weight_decay: 0.1
 betas:
launcher_scripts/conf/fine_tuning/chatglm/squad.yaml (1 addition & 1 deletion)
@@ -171,7 +171,7 @@ model:
 num_classes: null

 optim:
-name: distributed_fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
+name: mcore_distributed_optim # Supports distributed optimizer for memory savings. To enable, set to 'mcore_distributed_optim'. Needs Apex to be built with specific args to work.
 lr: 1e-6
 weight_decay: 0.1
 betas:
(file name not shown)
@@ -172,7 +172,7 @@ model:
 num_classes: null

 optim:
-name: distributed_fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
+name: mcore_distributed_optim # Supports distributed optimizer for memory savings. To enable, set to 'mcore_distributed_optim'. Needs Apex to be built with specific args to work.
 lr: 1e-6
 weight_decay: 0.1
 betas:
launcher_scripts/conf/fine_tuning/falcon/squad.yaml (1 addition & 1 deletion)
@@ -171,7 +171,7 @@ model:
 num_classes: null

 optim:
-name: distributed_fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
+name: mcore_distributed_optim # Supports distributed optimizer for memory savings. To enable, set to 'mcore_distributed_optim'. Needs Apex to be built with specific args to work.
 lr: 1e-6
 weight_decay: 0.1
 betas:
launcher_scripts/conf/fine_tuning/gpt3/custom_task.yaml (1 addition & 1 deletion)
@@ -174,7 +174,7 @@ model:
 num_classes: null

 optim:
-name: fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
+name: fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'mcore_distributed_optim'. Needs Apex to be built with specific args to work.
 lr: 5e-6
 weight_decay: 0.01
 betas:
launcher_scripts/conf/fine_tuning/gpt3/squad.yaml (1 addition & 1 deletion)
@@ -180,7 +180,7 @@ model:
 num_classes: null

 optim:
-name: fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
+name: fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'mcore_distributed_optim'. Needs Apex to be built with specific args to work.
 lr: 1e-6
 weight_decay: 0.1
 betas:
launcher_scripts/conf/fine_tuning/llama/squad.yaml (1 addition & 1 deletion)
@@ -178,7 +178,7 @@ model:
 num_classes: null

 optim:
-name: distributed_fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
+name: mcore_distributed_optim # Supports distributed optimizer for memory savings. To enable, set to 'mcore_distributed_optim'. Needs Apex to be built with specific args to work.
 lr: 1e-6
 weight_decay: 0.1
 betas:
launcher_scripts/conf/fine_tuning/mamba/sft.yaml (1 addition & 1 deletion)
@@ -224,7 +224,7 @@ model:
 num_classes: null

 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 lr: 2e-4
 weight_decay: 0.01
 betas:
launcher_scripts/conf/fine_tuning/mistral/squad.yaml (1 addition & 1 deletion)
@@ -171,7 +171,7 @@ model:
 num_classes: null

 optim:
-name: distributed_fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
+name: mcore_distributed_optim # Supports distributed optimizer for memory savings. To enable, set to 'mcore_distributed_optim'. Needs Apex to be built with specific args to work.
 lr: 1e-6
 weight_decay: 0.1
 betas:
launcher_scripts/conf/fine_tuning/mixtral/squad.yaml (1 addition & 1 deletion)
@@ -171,7 +171,7 @@ model:
 num_classes: null

 optim:
-name: distributed_fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
+name: mcore_distributed_optim # Supports distributed optimizer for memory savings. To enable, set to 'mcore_distributed_optim'. Needs Apex to be built with specific args to work.
 lr: 1e-6
 weight_decay: 0.1
 betas:
launcher_scripts/conf/fine_tuning/mixtral/squad_8x22b.yaml (1 addition & 1 deletion)
@@ -171,7 +171,7 @@ model:
 num_classes: null

 optim:
-name: distributed_fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
+name: mcore_distributed_optim # Supports distributed optimizer for memory savings. To enable, set to 'mcore_distributed_optim'. Needs Apex to be built with specific args to work.
 lr: 1e-6
 weight_decay: 0.1
 betas:
launcher_scripts/conf/fine_tuning/qwen2/squad.yaml (1 addition & 1 deletion)
@@ -178,7 +178,7 @@ model:
 num_classes: null

 optim:
-name: distributed_fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
+name: mcore_distributed_optim # Supports distributed optimizer for memory savings. To enable, set to 'mcore_distributed_optim'. Needs Apex to be built with specific args to work.
 lr: 1e-6
 weight_decay: 0.1
 betas:
launcher_scripts/conf/peft/gemma/sft.yaml (1 addition & 1 deletion)
@@ -231,7 +231,7 @@ model:
 num_classes: null

 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 lr: 1e-4
 weight_decay: 0.01
 betas:
launcher_scripts/conf/peft/griffin/sft.yaml (1 addition & 1 deletion)
@@ -283,7 +283,7 @@ model:
 num_classes: null

 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 lr: 1e-5
 weight_decay: 0.01
 betas:
launcher_scripts/conf/peft/llama/sft.yaml (1 addition & 1 deletion)
@@ -245,7 +245,7 @@ model:
 num_classes: null

 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 lr: 1e-4
 weight_decay: 0.01
 betas:
launcher_scripts/conf/peft/nemotron/sft.yaml (1 addition & 1 deletion)
@@ -245,7 +245,7 @@ model:
 num_classes: null

 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 lr: 1e-4
 weight_decay: 0.01
 betas:
launcher_scripts/conf/peft/qwen2/sft.yaml (1 addition & 1 deletion)
@@ -245,7 +245,7 @@ model:
 num_classes: null

 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 lr: 1e-4
 weight_decay: 0.01
 betas:
launcher_scripts/conf/rlhf_ppo/gpt3/2b_ppo.yaml (2 additions & 2 deletions)
@@ -88,7 +88,7 @@ critic:
 num_attributes: 1

 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 bucket_cap_mb: 200
 overlap_grad_sync: False
 contiguous_grad_buffer: True
@@ -261,7 +261,7 @@ actor:
 seed: 1234

 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 bucket_cap_mb: 200
 overlap_grad_sync: False
 contiguous_grad_buffer: True
launcher_scripts/conf/rlhf_rm/gpt3/2b_rm.yaml (1 addition & 1 deletion)
@@ -177,7 +177,7 @@ model:
 checkpoint_name: null

 optim:
-name: distributed_fused_adam
+name: mcore_distributed_optim
 bucket_cap_mb: 200
 overlap_grad_sync: False
 contiguous_grad_buffer: True