DeepSpeed is slower than FSDP #5047

Closed
halilakin opened this issue Jan 31, 2024 · 16 comments
Labels: bug (Something isn't working), training

@halilakin commented Jan 31, 2024

Describe the bug
I am still familiarizing myself with DeepSpeed, so this is a n00b question. I wrapped my model with DeepSpeed and I'm seeing good ZeRO-2 performance. However, when I switch to ZeRO-3, the all_gathers are not overlapping and they are very fragmented, even though the default params look good. How can I learn more about why the all_gathers are so fragmented, and how can I make them less granular?

[screenshot: ZeRO-3 trace showing fragmented, non-overlapping all_gathers]
DeepSpeedEngine configuration:
activation_checkpointing_config  {
    "partition_activations": false, 
    "contiguous_memory_optimization": false, 
    "cpu_checkpointing": false, 
    "number_checkpoints": null, 
    "synchronize_checkpoint_boundary": false, 
    "profile": false
}
aio_config ................... {'block_size': 1048576, 'queue_depth': 8, 'thread_count': 1, 'single_submit': False, 'overlap_events': True}
amp_enabled .................. False
amp_params ................... False
autotuning_config ............ {
    "enabled": false, 
    "start_step": null, 
    "end_step": null, 
    "metric_path": null, 
    "arg_mappings": null, 
    "metric": "throughput", 
    "model_info": null, 
    "results_dir": "autotuning_results", 
    "exps_dir": "autotuning_exps", 
    "overwrite": true, 
    "fast": true, 
    "start_profile_step": 3, 
    "end_profile_step": 5, 
    "tuner_type": "gridsearch", 
    "tuner_early_stopping": 5, 
    "tuner_num_trials": 50, 
    "model_info_path": null, 
    "mp_size": 1, 
    "max_train_batch_size": null, 
    "min_train_batch_size": 1, 
    "max_train_micro_batch_size_per_gpu": 1.024000e+03, 
    "min_train_micro_batch_size_per_gpu": 1, 
    "num_tuning_micro_batch_sizes": 3
}
bfloat16_enabled ............. True
checkpoint_parallel_write_pipeline  False
checkpoint_tag_validation_enabled  True
checkpoint_tag_validation_fail  False
comms_config ................. <deepspeed.comm.config.DeepSpeedCommsConfig object at 0x7f09382c8a90>
communication_data_type ...... None
compression_config ........... {'weight_quantization': {'shared_parameters': {'enabled': False, 'quantizer_kernel': False, 'schedule_offset': 0, 'quantize_groups': 1, 'quantize_verbose': False, 'quantization_type': 'symmetric', 'quantize_weight_in_forward': False, 'rounding': 'nearest', 'fp16_mixed_quantize': False, 'quantize_change_ratio': 0.001}, 'different_groups': {}}, 'activation_quantization': {'shared_parameters': {'enabled': False, 'quantization_type': 'symmetric', 'range_calibration': 'dynamic', 'schedule_offset': 1000}, 'different_groups': {}}, 'sparse_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'row_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'head_pruning': {'shared_parameters': {'enabled': False, 'method': 'topk', 'schedule_offset': 1000}, 'different_groups': {}}, 'channel_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'layer_reduction': {'enabled': False}}
curriculum_enabled_legacy .... False
curriculum_params_legacy ..... False
data_efficiency_config ....... {'enabled': False, 'seed': 1234, 'data_sampling': {'enabled': False, 'num_epochs': 1000, 'num_workers': 0, 'curriculum_learning': {'enabled': False}}, 'data_routing': {'enabled': False, 'random_ltd': {'enabled': False, 'layer_token_lr_schedule': {'enabled': False}}}}
data_efficiency_enabled ...... False
dataloader_drop_last ......... False
disable_allgather ............ False
dump_state ................... False
dynamic_loss_scale_args ...... None
eigenvalue_enabled ........... False
eigenvalue_gas_boundary_resolution  1
eigenvalue_layer_name ........ bert.encoder.layer
eigenvalue_layer_num ......... 0
eigenvalue_max_iter .......... 100
eigenvalue_stability ......... 1e-06
eigenvalue_tol ............... 0.01
eigenvalue_verbose ........... False
elasticity_enabled ........... False
flops_profiler_config ........ {
    "enabled": false, 
    "recompute_fwd_factor": 0.0, 
    "profile_step": 1, 
    "module_depth": -1, 
    "top_modules": 1, 
    "detailed": true, 
    "output_file": null
}
fp16_auto_cast ............... None
fp16_enabled ................. False
fp16_master_weights_and_gradients  False
global_rank .................. 0
grad_accum_dtype ............. None
gradient_accumulation_steps .. 4
gradient_clipping ............ 1.0
gradient_predivide_factor .... 1.0
graph_harvesting ............. False
hybrid_engine ................ enabled=False max_out_tokens=512 inference_tp_size=1 release_inference_cache=False pin_parameters=True tp_gather_partition_size=8
initial_dynamic_scale ........ 1
load_universal_checkpoint .... False
loss_scale ................... 1.0
memory_breakdown ............. False
mics_hierarchial_params_gather  False
mics_shard_size .............. -1
monitor_config ............... tensorboard=TensorBoardConfig(enabled=False, output_path='', job_name='DeepSpeedJobName') wandb=WandbConfig(enabled=False, group=None, team=None, project='deepspeed') csv_monitor=CSVConfig(enabled=False, output_path='', job_name='DeepSpeedJobName') enabled=False
nebula_config ................ {
    "enabled": false, 
    "persistent_storage_path": null, 
    "persistent_time_interval": 100, 
    "num_of_version_in_retention": 2, 
    "enable_nebula_load": true, 
    "load_path": null
}
optimizer_legacy_fusion ...... False
optimizer_name ............... None
optimizer_params ............. None
pipeline ..................... {'stages': 'auto', 'partition': 'best', 'seed_layers': False, 'activation_checkpoint_interval': 0, 'pipe_partitioned': True, 'grad_partitioned': True}
pld_enabled .................. False
pld_params ................... False
prescale_gradients ........... False
scheduler_name ............... None
scheduler_params ............. None
seq_parallel_communication_data_type  torch.float32
sparse_attention ............. None
sparse_gradients_enabled ..... False
steps_per_print .............. 1
train_batch_size ............. 1024
train_micro_batch_size_per_gpu  8
use_data_before_expert_parallel_  False
use_node_local_storage ....... False
wall_clock_breakdown ......... False
weight_quantization_config ... None
world_size ................... 32
zero_allow_untested_optimizer  False
zero_config ..................
stage=3
contiguous_gradients=True
reduce_scatter=True
reduce_bucket_size=500000000
use_multi_rank_bucket_allreduce=True
allgather_partitions=True
allgather_bucket_size=500000000
overlap_comm=True
load_from_fp32_weights=True
elastic_checkpoint=False
offload_param=None
offload_optimizer=None
sub_group_size=1,000,000,000
cpu_offload_param=None
cpu_offload_use_pin_memory=None
cpu_offload=None
prefetch_bucket_size=50,000,000
param_persistence_threshold=100,000
model_persistence_threshold=sys.maxsize
max_live_parameters=1,000,000,000
max_reuse_distance=1,000,000,000
gather_16bit_weights_on_model_save=False
stage3_gather_fp16_weights_on_model_save=False
ignore_unused_parameters=True
legacy_stage1=False
round_robin_gradients=False
zero_hpz_partition_size=1
zero_quantized_weights=False
zero_quantized_nontrainable_weights=False
zero_quantized_gradients=False
mics_shard_size=-1
mics_hierarchical_params_gather=False
memory_efficient_linear=True
pipeline_loading_checkpoint=False
override_module_apply=True
halilakin added the bug (Something isn't working) and training labels on Jan 31, 2024
@halilakin (Author):
This is definitely due to DS wrapping very small nn.Modules such as nn.Linear instead of larger parts of the model such as TransformerLayer. I assumed allgather_bucket_size=5e8 would limit this behavior. Why isn't it doing that?

[screenshot: trace showing separate all_gathers for individual nn.Linear modules]

@halilakin (Author):
This turned out to be a cascade of issues, but it was resolved by a combination of leaf modules and stage3_param_persistence_threshold.
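For readers skimming the thread, here is a minimal sketch of where that threshold lives in a ZeRO-3 config. The values are illustrative assumptions (they mirror the defaults in the dump above), not recommendations, and the leaf-module call is shown later in the thread.

```python
# Sketch of a ZeRO-3 config dict; values are placeholders to illustrate the knobs.
ds_config = {
    "train_micro_batch_size_per_gpu": 8,
    "gradient_accumulation_steps": 4,
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        # Params smaller than this many elements stay resident on every rank
        # instead of being gathered on demand; raising it trades memory for
        # fewer tiny all_gathers.
        "stage3_param_persistence_threshold": 1e5,
        # How much parameter data (in elements) ZeRO-3 prefetches ahead of use.
        "stage3_prefetch_bucket_size": 5e7,
        "stage3_max_live_parameters": 1e9,
    },
}
```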

@tohtana (Contributor) commented Feb 3, 2024

Thank you, @halilakin, for sharing your experience! We apologize for not being able to respond to you earlier.

I am curious about your performance evaluation for future improvements, but the link to the FSDP results seems to be broken. Could you please update the link and let us know which model you used and the configurations?

halilakin changed the title from "[BUG] Very slow Z3 training out of the box" to "DeepSpeed is slower than FSDP" on Feb 5, 2024
halilakin reopened this on Feb 5, 2024
@halilakin (Author) commented Feb 5, 2024

Thanks for offering to help here, @tohtana. I am back on this task (and renamed the issue accordingly) and still trying to figure out why I can't match PyTorch FSDP's performance with ZeRO-3. There are several issues, but here is the top one: ZeRO-3 launches too many all_gathers.

Here is the forward pass of a 6-layer transformer model with FSDP:

[screenshot: FSDP forward pass trace]

And here is the same forward pass with ZeRO-3:

[screenshot: ZeRO-3 forward pass trace]

You can see that ZeRO-3 is slower, probably because it launches many small all_gathers instead of one all_gather per layer as FSDP does. At least, that was my theory. I tried to limit the number of all_gathers by setting the transformer layer as a leaf node:

[screenshot: ZeRO-3 trace with the transformer layer set as a leaf node]

However, it doesn't seem to have made things faster. I am also not sure whether this is the right solution, since the leaf-module feature is a very recent addition to ZeRO-3. How did people solve this issue before?

Another issue: I tried increasing allgather_bucket_size, but it doesn't seem to do anything. I also tried increasing prefetch_bucket_size, hoping it would prefetch layers ahead of time, but that didn't help either.

The fact that every transformer layer takes equal time with ZeRO-3 in the last picture tells me that we are making the CPU wait instead of letting it run ahead, but I am not sure what's causing that.

cc: @stas00

@halilakin (Author) commented Feb 5, 2024

Fundamentally, I think there are two controls I am unable to make work with ZeRO-3.

(1) The ability to bundle multiple parameters together and kick off a single all_gather for them. My understanding is that param_persistence_threshold makes small parameters persist instead of bundling them together with other parameters and partitioning a bigger nn.Module.

However, this in and of itself doesn't seem to be the whole issue, even though it increases the overall latency by roughly 50% on average compared to FSDP simply fetching all of a layer's parameters at once.

(2) The ability to prefetch those bundled parameters ahead of time. Neither allgather_bucket_size nor prefetch_bucket_size seems to do this in a way that keeps the GPU from going idle.

This is currently how the trace looks with the following extra settings (without the leaf module)

"stage3_prefetch_bucket_size": 1e10,  # 10B params, everything, just to test if prefetching works
"allgather_bucket_size": 1e10,
"stage3_param_persistence_threshold": 1e8,  # 100M params
[screenshot: ZeRO-3 trace with the settings above]

@tohtana (Contributor) commented Feb 5, 2024

@halilakin Thank you for sharing your detailed investigation! We appreciate your kind help.

It seems that synchronization is not working as expected. We had anticipated that the leaf module of ZeRO-3 would enhance performance in certain scenarios, but it appears insufficient to address the inefficiency. We will work on resolving this issue and provide you with updates as we make progress.

@halilakin (Author):
Thanks for the quick response, @tohtana. I've stopped using the leaf module for now until I fully understand all the knobs and finish reading the code. I will update the thread with more information, but it's quite possible that I am not setting all the parameters correctly.

@halilakin (Author) commented Feb 6, 2024

In my experience and setup, I found FSDP to be approximately 10% faster than ZeRO. This appears to be because FSDP consolidates the parameters of an entire layer into a single all_gather, rather than initiating multiple all_gathers for components such as MHA and FFN as ZeRO does, which decreases communication throughput.

I have attempted to reduce the number of all_gathers ZeRO launches using two methods:

  • Increasing the param_persistence_threshold. This approach led to fewer all_gathers but significantly increased memory consumption, as it requires persisting MHA or large modules to reach the point where only one all_gather is launched per layer (e.g., FFN). It was still slower than FSDP and it also does not align with my goal of bucketing all parameters in a layer into a single all_gather.

  • Utilizing leaf modules. This method has not been very effective; it results in unutilized areas within the GPU memory for reasons that are unclear to me. Given that this seems to be a recent addition to the ZeRO codebase, I have not explored it in depth.

I also tried adjusting the allgather_bucket_size parameter, but it doesn't seem to work. If there are other ways to make ZeRO aggregate many nn.Modules and initiate a single all_gather for all their parameters, please let me know.
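For comparison, here is a sketch of how the FSDP baseline described above gets one all_gather per layer: wrap at the transformer-block boundary with an auto-wrap policy. MyTransformerBlock and the surrounding setup are placeholders for illustration, not the actual model from this issue, and the snippet assumes torch.distributed is already initialized (e.g. via torchrun).

```python
import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

class MyTransformerBlock(torch.nn.Module):  # stand-in for the real block class
    def __init__(self, d: int = 1024):
        super().__init__()
        self.attn = torch.nn.MultiheadAttention(d, 8, batch_first=True)
        self.ffn = torch.nn.Sequential(
            torch.nn.Linear(d, 4 * d), torch.nn.GELU(), torch.nn.Linear(4 * d, d)
        )

model = torch.nn.Sequential(*[MyTransformerBlock() for _ in range(6)])

# Wrap every transformer block as its own FSDP unit.
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={MyTransformerBlock},
)

# Each block becomes one FSDP unit, so all of its parameters are fetched
# with a single all_gather in forward (and again in backward).
fsdp_model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,
    device_id=torch.cuda.current_device(),
)
```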

@tohtana (Contributor) commented Feb 6, 2024

Hi @halilakin,
Thank you for investigating the issue! I think we can do the same as FSDP in theory using the leaf module feature.
You can specify the transformer layer class as a leaf module; the feature has been merged into the master branch.
If it doesn't work, we may have issues with communication synchronization. Our team can also check that part.

@halilakin (Author):
Thanks @tohtana. Let me test the leaf module thoroughly in combination with other flags today and update the thread.

@halilakin (Author) commented Feb 7, 2024

I was able to resolve the fragmented all_gathers by using the leaf module feature, setting it before DeepSpeed initialization so that the hooks attach only to the transformer blocks. Previously, I was setting it after initialization, following the original example.
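For anyone hitting the same wall, a sketch of the ordering that worked, assuming deepspeed.utils.set_z3_leaf_modules is the leaf-module API being discussed; MyTransformerBlock and ds_config are placeholders for the actual block class and ZeRO-3 config.

```python
import deepspeed
from deepspeed.utils import set_z3_leaf_modules

# Mark the transformer block as a ZeRO-3 leaf *before* deepspeed.initialize,
# so the partitioning hooks attach at the block boundary instead of at every
# small nn.Linear inside it.
set_z3_leaf_modules(model, [MyTransformerBlock])

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,  # ZeRO-3 config as sketched earlier in the thread
)
```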

I can attribute a significant portion of the remaining performance gap to the post-step all-gathering of persistent parameters here, as you can also see in the trace.

[screenshot: trace showing the post-step all_gather of persistent parameters]

What's the purpose of this expensive all_gather? I thought persistent parameters were not partitioned. Why do they need to be all_gathered at every step?

@halilakin (Author) commented Feb 7, 2024

I am having trouble explaining the rest of the efficiency gap. The FSDP trace is very easy for me to reason about.

[screenshot: FSDP trace of a full training step]

There are exactly 6 all_gathers during the forward pass, and 6 all_gathers plus reduce_scatters in the backward pass (issued in reverse order to overlap communication with computation). There is some GPU idleness, but overall things look very homogeneous.

This is the same step for the 6-layer network with ZeRO-3:

[screenshot: ZeRO-3 trace of the same training step]

I have several open questions here: Why is the GPU idle between layers during the forward pass? Why does the backward pass issue so many all_gathers? How can I exclude parts of my network from being partitioned? And so on.

@tohtana (Contributor) commented Feb 7, 2024

Hi @halilakin,
We run an allgather for persistent parameters because each rank's optimizer updates only its own shard of the parameters; the full parameters need to be updated for the following forward pass.
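To illustrate that point, here is a toy torch.distributed sketch, not DeepSpeed's actual code path: each rank updates only its own shard, then an all_gather reassembles the full tensor that every rank keeps resident. It assumes the parameter's element count divides evenly by the world size.

```python
import torch
import torch.distributed as dist

def step_persistent_param(full_param: torch.Tensor, lr: float = 1e-3) -> None:
    """Toy illustration: optimizer updates are sharded across ranks,
    but a persistent parameter must be whole on every rank afterwards."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # Each rank owns one contiguous shard of the flattened parameter.
    shards = full_param.detach().flatten().chunk(world_size)
    my_shard = shards[rank].clone()

    # The optimizer step touches only the local shard (fake update here).
    my_shard.add_(torch.randn_like(my_shard), alpha=-lr)

    # Reassemble the updated parameter on every rank for the next forward.
    gathered = [torch.empty_like(my_shard) for _ in range(world_size)]
    dist.all_gather(gathered, my_shard)
    full_param.data.copy_(torch.cat(gathered).view_as(full_param))
```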

I think the idle time between layers was caused by the overhead associated with handling hooks that launch allgather/reduce_scatter, and some other operations. Unfortunately, I'm unable to view the details of the profiled timeline (for some reason, downloading or zooming in was not possible). If you could provide more details, I might have better insight into the backward pass.

We are actively working to enhance the performance of ZeRO. Your feedback is valuable to our efforts.

@halilakin (Author) commented Feb 8, 2024

I was able to get ZeRO to 95% of the MFU I got with FSDP for my network, primarily by using leaf nodes and setting the ZeRO-3 parameters better. The remaining gap appears to be due to ZeRO partitioning everything in my model. There are parts of my model that shouldn't be partitioned at all, since they are computation-heavy but memory-light, but there seems to be no straightforward way to exclude parts of a network from being partitioned with ZeRO.

@GuanhuaWang (Contributor) commented Feb 8, 2024

> I was able to get ZeRO to 95% of the MFU I got with FSDP for my network, primarily by using leaf nodes and setting the ZeRO-3 parameters better. The remaining gap appears to be due to ZeRO partitioning everything in my model. There are parts of my model that shouldn't be partitioned at all, since they are computation-heavy but memory-light, but there seems to be no straightforward way to exclude parts of a network from being partitioned with ZeRO.

Hi @halilakin, it was really nice chatting with you yesterday. Given that we are now clear about all the collective communication calls and performance is acceptable, I will close this issue for now. I have already noted the feature you requested. Thanks a ton!

@halilakin (Author):
Thanks for the help @GuanhuaWang!
