LoRA training with FSDP has spike in train loss #8487
Comments
Same issue in FSDP + SFT.
Environment details
Docker Container:
Hi. Looks like this issue is fixed in the 24.03 release container. I can't reproduce this divergence with 24.03.
Did you test FSDP + SFT or FSDP + LoRA?
@dimapihtar I tried FSDP + LoRA tuning with the same experiment on the 24.03 container, and it still has the same train loss issue.
@dimapihtar FSDP + SFT on the 24.03 container also has the same training loss issue.
I also see very significant loss differences in pre-training when comparing FSDP vs. DDP, with FSDP being worse. I'm using FP32 gradient reduction with FSDP and the original params (i.e., `fsdp_use_orig_params=True`).
The problem is a double all-reduce on the sharded parameter gradients when setting `fsdp_use_orig_params=True`:

```diff
  self.megatron_timer_start('gradient_allreduce', log_level=1)
- if self.use_fsdp:
+ if self.use_fsdp and self.cfg.get('fsdp_use_orig_params', False):
+     assert self.cfg.get('tensor_model_parallel_size', 1) == 1, 'fix only works for TP=1'
+ elif self.use_fsdp:
      # Reduce the gradients omitted from FSDP-sharding
      self.allreduce_fsdp_sharding_omitted_gradients()
  elif self.with_distributed_adam:
```

I still need to find a nice way to fix it with TP > 1.
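For context on why the `fsdp_use_orig_params` check matters: with `use_orig_params=False`, `model.parameters()` yields FSDP `FlatParameter` objects, so the original `isinstance` filter only catches parameters FSDP left unsharded; with `use_orig_params=True`, every parameter is exposed as an original `nn.Parameter`, so the filter also matches the FSDP-sharded parameters and their gradients get all-reduced a second time on top of FSDP's own reduction. Below is a minimal single-process sketch (plain PyTorch, not NeMo code; assumes PyTorch 2.x, a gloo process group, and the same `FlatParameter` import path used by the check above) that shows the difference:

```python
import os

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FlatParameter  # same class the original NeMo check uses
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Single-process "distributed" setup just so FSDP can be constructed on CPU.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)


def flat_param_summary(model: nn.Module) -> str:
    flat = sum(isinstance(p, FlatParameter) for p in model.parameters())
    total = sum(1 for _ in model.parameters())
    return f"{flat}/{total} parameters are FlatParameters"


# use_orig_params=False: parameters() exposes FSDP's flattened parameters, so the
# "not isinstance(param, FlatParameter)" filter skips everything FSDP sharded.
print("use_orig_params=False ->",
      flat_param_summary(FSDP(nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)),
                              use_orig_params=False)))

# use_orig_params=True: parameters() exposes the original nn.Parameters, so the same
# filter now matches the sharded parameters too and their gradients get all-reduced
# a second time on top of FSDP's own gradient reduction.
print("use_orig_params=True  ->",
      flat_param_summary(FSDP(nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)),
                              use_orig_params=True)))

dist.destroy_process_group()
```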
The fix for `allreduce_fsdp_sharding_omitted_gradients()`:

```diff
  grads = []
- for param in self.model.parameters():
-     if not isinstance(param, torch.distributed.fsdp.FlatParameter) and param.requires_grad:
+ for param in self.model._ignored_params:
+     if param.requires_grad and param.grad is not None:
          grad = param.grad
          grads.append(grad.data)
```
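For readers unfamiliar with that helper, here is a rough sketch of what the patched loop amounts to: all-reduce only the gradients of parameters the FSDP wrapper was told to ignore, which is what `_ignored_params` (the private FSDP attribute used in the patch above) holds. This is not the actual NeMo implementation; it assumes plain `torch.distributed` with gradients averaged over the default process group.

```python
import torch.distributed as dist


def allreduce_fsdp_ignored_grads(fsdp_model) -> None:
    """All-reduce gradients of parameters FSDP did not shard (its ignored params)."""
    grads = [
        p.grad.data
        for p in fsdp_model._ignored_params  # private FSDP attribute used by the patch above
        if p.requires_grad and p.grad is not None
    ]
    world_size = dist.get_world_size()
    for grad in grads:
        # FSDP never reduced these gradients, so average them across ranks here.
        dist.all_reduce(grad)
        grad.div_(world_size)
```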
Describe the bug
We'd like to use FSDP when doing LoRA fine-tuning with large LLMs. In our experiments, the train loss with FSDP is very different from the train loss without FSDP: with FSDP the train loss is extremely high, whereas regular LoRA tuning without FSDP converges normally. We also noticed that the forward-pass outputs with FSDP differ from those without FSDP; specifically, the forward pass with FSDP + LoRA appears to produce lower-magnitude outputs than the forward pass with LoRA alone (a diagnostic sketch follows the loss comparison below).
Train Loss Comparison:
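An illustrative way to confirm the forward-pass symptom described above is to run the same batch through an un-wrapped copy of the model and the FSDP-wrapped model and compare output magnitudes. This is not the reporter's actual script: the model and batch are placeholders, a `torch.distributed` process group is assumed to already be initialized, and `use_orig_params=True` simply mirrors the configuration discussed in the comments above.

```python
import copy

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


@torch.no_grad()
def compare_forward(model: torch.nn.Module, batch: torch.Tensor) -> None:
    reference = copy.deepcopy(model)             # un-wrapped reference copy
    wrapped = FSDP(model, use_orig_params=True)  # placeholder for the failing configuration
    ref_out = reference(batch)
    fsdp_out = wrapped(batch)
    print("reference mean |output|:", ref_out.abs().mean().item())
    print("FSDP      mean |output|:", fsdp_out.abs().mean().item())
    print("max abs difference     :", (ref_out - fsdp_out).abs().max().item())
```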
Steps/Code to reproduce bug
Please list minimal steps or code snippet for us to be able to reproduce the bug.
Expected behavior
We expect a normal train loss curve matching that seen when not using FSDP.
Environment overview (please complete the following information)
We used enroot to set up the NeMo Docker container.
Environment details
Docker Container:
nvcr.io/nvidia/nemo:24.01.framework
Additional context
We added new code to wrap the LoRA parameters with the FSDP wrapper to ensure they are sharded correctly. From our train logs, the sharding does appear to happen correctly, so we're not sure where the forward-pass bug is coming from (a sketch of this kind of wrapping follows below).
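The wrapping code itself is not shown in this issue; the following is only a minimal sketch of one common way to give adapter parameters their own FSDP units, using a lambda auto-wrap policy that matches adapter modules. The `"lora"` class-name check and the function names are assumptions for illustration, not NeMo's or the reporter's actual code.

```python
import functools

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy


def is_adapter_module(module: nn.Module) -> bool:
    # Assumption: LoRA adapter classes have "lora" somewhere in their class name.
    return "lora" in type(module).__name__.lower()


def wrap_model_with_adapters(model: nn.Module) -> FSDP:
    # Wrap any module matched by is_adapter_module into its own FSDP unit, so adapter
    # parameters are sharded alongside the frozen base-model parameters.
    policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=is_adapter_module)
    return FSDP(model, auto_wrap_policy=policy, use_orig_params=True)
```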