LoRA training with FSDP has spike in train loss #8487

Closed
zpx01 opened this issue Feb 22, 2024 · 8 comments

zpx01 (Collaborator) commented Feb 22, 2024

Describe the bug

We'd like to use FSDP when doing LoRA fine-tuning with large LLMs. In our experiments, the train loss with FSDP is very different from the train loss without FSDP: with FSDP the train loss is extremely high, whereas a regular LoRA tuning without FSDP converges normally. We also noticed that the forward-pass outputs differ between the two settings; specifically, the forward pass with FSDP + LoRA produces outputs of lower magnitude than the forward pass with LoRA alone.
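
For reference, here is a minimal sketch of the kind of comparison we ran (the helper and model construction below are illustrative placeholders, not the actual NeMo entry points, and an initialized process group is assumed): build the same LoRA-adapted model twice, wrap only one copy in FSDP, push an identical batch through both, and compare output magnitudes.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def compare_forward(build_model, batch):
    # build_model() constructs an identical LoRA-adapted model on each call
    # (placeholder; in practice this is the NeMo GPT model with LoRA applied).
    torch.manual_seed(0)
    plain = build_model().cuda()
    torch.manual_seed(0)
    sharded = FSDP(build_model().cuda(), use_orig_params=True)

    with torch.no_grad():
        out_plain = plain(**batch)
        out_sharded = sharded(**batch)

    print("no-FSDP output norm:", out_plain.norm().item())
    print("FSDP output norm:   ", out_sharded.norm().item())
    print("max abs diff:       ", (out_plain - out_sharded).abs().max().item())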

Train Loss Comparison:
[Screenshot, 2024-02-22: train loss curves with FSDP vs. without FSDP]

Steps/Code to reproduce bug


  1. Mount the changes in this PR to include the FSDP + LoRA integration.
  2. Use the following script to run an example fine-tuning job:
#!/bin/bash
set -x

HOME=/lustre/fsw/coreai_dlalgo_llm/zeeshanp
NEMO=$HOME/NeMo
MEGATRON=$HOME/Megatron-LM

PEFT_SCHEME='lora'
MODEL_SIZE=7b
MBS=1
TP=1
PP=1
NUM_DEVICES=2
GBS=8
SEQ_LEN=4096
DATASET=squad  # dolly, squad

if [ "${PEFT_SCHEME}" == "sft" ]; then
	PEFT_SCHEME_=none
	EXTRA_ARGS="model.optim.name=distributed_fused_adam \ # don't work
                model.optim.lr=5e-6
	           "
else
	PEFT_SCHEME_=${PEFT_SCHEME}
	EXTRA_ARGS=""
fi

EXTRA_ARGS+="
	+model.fp8=False \
	+model.fp8_e4m3=False \
	+model.fp8_hybrid=True \
	+model.fp8_margin=0 \
	+model.fp8_interval=1 \
	+model.fp8_amax_history_len=1024 \
	+model.fp8_amax_compute_algo=max
	"

MODEL=$HOME/pretrained_models/Llama-2-${MODEL_SIZE}-bf16-mcore
EXP_DIR=$HOME/exp/fsdp_test/interactive-llama2-${MODEL_SIZE}-${PEFT_SCHEME}_${DATASET}

if [ "${DATASET}" == "dolly" ]; then
	TRAIN_DS=[$HOME/datasets/dolly_packed/first_fit_shuffle_with_loss_mask/dolly_packed_${SEQ_LEN}_seed0.npy]
	VALID_DS=[$HOME/datasets/dolly/validation.jsonl]
elif [ "${DATASET}" == "squad" ]; then
	TRAIN_DS=[$HOME/datasets/squad_packed/first_fit_shuffle/packed_${SEQ_LEN}_seed0.npy]
	VALID_DS=[$HOME/datasets/squad/1_squad_validation.jsonl]
else
	echo "Unknown dataset ${DATASET}"
	exit
fi

WANDB_PROJECT="fsdp_lora"
wandb login $WANDB_API_KEY

export PYTHONPATH="${NEMO}/.:${MEGATRON}/.:${PYTHONPATH}"
# export PYTHONPATH="${NEMO}/.:${PYTHONPATH}"
torchrun --nproc_per_node=${NUM_DEVICES} ${NEMO}/examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py \
	trainer.devices=${NUM_DEVICES} \
	trainer.num_nodes=1 \
	trainer.val_check_interval=5 \
	trainer.max_steps=100 \
	+trainer.num_sanity_val_steps=0 \
	+trainer.limit_val_batches=10 \
	model.megatron_amp_O2=False \
	exp_manager.create_wandb_logger=False \
	exp_manager.wandb_logger_kwargs.project=llama_packed_${DATASET} \
	exp_manager.wandb_logger_kwargs.name=fsdp_test \
	exp_manager.resume_if_exists=False \
	exp_manager.exp_dir="${EXP_DIR}" \
	model.tensor_model_parallel_size=${TP} \
	model.pipeline_model_parallel_size=${PP} \
	model.micro_batch_size=${MBS} \
	model.global_batch_size=${GBS} \
	model.restore_from_path=${MODEL} \
	model.data.train_ds.num_workers=0 \
	model.data.validation_ds.num_workers=0 \
	+model.data.train_ds.packed_sequence=True \
	+model.log_token_counts=True \
	model.data.train_ds.file_names=${TRAIN_DS} \
	model.data.train_ds.concat_sampling_probabilities=[1.0] \
	model.data.validation_ds.file_names=${VALID_DS} \
	model.peft.peft_scheme=${PEFT_SCHEME_} \
	model.data.train_ds.max_seq_length=${SEQ_LEN} \
	model.data.validation_ds.max_seq_length=${SEQ_LEN} \
	model.sequence_parallel=False \
	+model.apply_rope_fusion=True \
	${EXTRA_ARGS} \
	trainer.precision=bf16 \
	model.answer_only_loss=True \
	model.fsdp=True \
	model.fsdp_sharded_checkpoint=True \
	model.fsdp_use_orig_params=True

Expected behavior

We expect a normal train loss curve matching that seen when not using FSDP.

Environment overview (please complete the following information)

We used enroot to set up the NeMo Docker container.

Environment details

Docker Container: nvcr.io/nvidia/nemo:24.01.framework

Additional context

We added new code to wrap the parameters used by LoRA with the FSDP wrapper so that they are sharded correctly. From our train logs the sharding does appear to happen correctly, so we're not sure where the bug in the forward pass is coming from.
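
For illustration, here is a rough sketch of the wrapping approach (the class names are placeholders and the policy wiring is simplified relative to the actual PR): the LoRA adapter modules are added to the FSDP auto-wrap policy alongside the transformer layers, so the adapter parameters are flat-sharded as well.

import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

def wrap_with_fsdp(model, transformer_layer_cls, lora_adapter_cls):
    # Wrap both the transformer layers and the LoRA adapter modules so the
    # adapter parameters are sharded too (placeholder classes, simplified policy).
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={transformer_layer_cls, lora_adapter_cls},
    )
    return FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        use_orig_params=True,  # matches model.fsdp_use_orig_params=True above
        device_id=torch.cuda.current_device(),
    )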

zpx01 added the bug label on Feb 22, 2024
@AtsunoriFujita

Same issue with FSDP + SFT.
The loss is 3-4x higher than in the TP, PP, and DP cases. I initially suspected fused_adam (currently, distributed_fused_adam cannot be used with FSDP), but fused_adam works fine in every configuration except FSDP.

Environment details

Docker Container: nvcr.io/nvidia/nemo:24.01.01.framework

#!/bin/sh

export HYDRA_FULL_ERROR=1

export WANDB=False
export PJ_NAME="databricks-dolly-15k-ja"
export EXP_DIR="/raid/results/"${PJ_NAME}
export EXP_NAME="Swallow-MS-7b-v0.1_SFT"
export MODEL="/workspace/models/Swallow-MS-7b-v0.1/Swallow-MS-7b-v0.1.nemo"
export TP_SIZE=1
export PP_SIZE=1
export TRAIN_DS="[/workspace/data/"${PJ_NAME}"/train.jsonl]"
export VALID_DS="[/workspace/data/"${PJ_NAME}"/valid.jsonl]"
export TEST_DS="[/workspace/data/"${PJ_NAME}"/test.jsonl]"

torchrun --nproc_per_node=8 /opt/NeMo/examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py \
    trainer.precision=bf16-mixed \
    trainer.devices=8 \
    trainer.num_nodes=1 \
    trainer.max_epochs=5 \
    trainer.max_steps=938 \
    trainer.val_check_interval=0.1 \
    model.peft.peft_scheme=null \
    model.restore_from_path=${MODEL} \
    model.global_batch_size=64 \
    model.micro_batch_size=4 \
    model.tensor_model_parallel_size=${TP_SIZE} \
    model.pipeline_model_parallel_size=${PP_SIZE} \
    model.megatron_amp_O2=False \
    model.fsdp=True \
    model.fsdp_sharding_strategy='full' \
    model.fsdp_grad_reduce_dtype='bf16' \
    model.fsdp_sharded_checkpoint=False \
    model.fsdp_use_orig_params=False \
    model.optim.name=fused_adam \
    model.optim.lr=2e-5 \
    model.answer_only_loss=True \
    model.data.train_ds.file_names=${TRAIN_DS} \
    model.data.train_ds.concat_sampling_probabilities="[1]" \
    model.data.train_ds.max_seq_length=2048 \
    model.data.train_ds.num_workers=0 \
    model.data.validation_ds.file_names=${VALID_DS} \
    model.data.validation_ds.max_seq_length=2048 \
    model.data.validation_ds.num_workers=0 \
    model.data.validation_ds.metric.name=loss \
    model.data.test_ds.file_names=${TEST_DS} \
    exp_manager.explicit_log_dir=${EXP_DIR} \
    exp_manager.name=${EXP_NAME} \
    exp_manager.create_wandb_logger=${WANDB} \
    exp_manager.wandb_logger_kwargs.project=${PJ_NAME} \
    exp_manager.wandb_logger_kwargs.name=${EXP_NAME} \
    exp_manager.checkpoint_callback_params.save_nemo_on_train_end=True \
    exp_manager.resume_if_exists=True \
    exp_manager.resume_ignore_no_checkpoint=True \
    exp_manager.create_checkpoint_callback=True \
    exp_manager.checkpoint_callback_params.monitor=validation_loss \
    exp_manager.early_stopping_callback_params.patience=5 \
    ++cluster_type=BCP

@dimapihtar (Collaborator)

Hi. It looks like this issue is fixed in the 24.03 release container. I can't reproduce this divergence with 24.03.

cuichenx (Collaborator) commented Mar 26, 2024

> Hi. It looks like this issue is fixed in the 24.03 release container. I can't reproduce this divergence with 24.03.

Did you test FSDP + SFT or FSDP + LoRA?

zpx01 (Collaborator, Author) commented Mar 26, 2024

@dimapihtar I tried the FSDP + LoRA tuning with the same experiment but on the 24.03 container and it still has the same train loss issue.

@AtsunoriFujita

@dimapihtar FSDP+SFT on the 24.03 container also has the same training loss issue.

janEbert (Contributor) commented May 27, 2024

I also see very significant loss differences in pre-training when comparing FSDP vs. DDP, with FSDP being worse. I'm using FP32 gradient reduction with FSDP and the original params (i.e., model.fsdp_grad_reduce_dtype=32 and model.fsdp_use_orig_params=True).
Note that this difference does not appear when using FSDP without sharding (i.e., sharding_strategy=ShardingStrategy.NO_SHARD, which currently has no config setting; see the sketch below) or with fsdp_use_orig_params=False.
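
For anyone who wants to reproduce the NO_SHARD comparison, this is roughly how it can be forced with plain PyTorch FSDP (there is no NeMo config knob for it at the moment, so this is a hand-edited sketch rather than a supported option):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# Sketch only: wrap the (already built) model with sharding disabled, so FSDP
# behaves like DDP with respect to parameter storage.
fsdp_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.NO_SHARD,
    use_orig_params=True,
)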

janEbert (Contributor) commented May 29, 2024

The problem is a double all-reduce on the sharded parameter gradients when setting fsdp_use_orig_params=True. For the case of tensor_model_parallel_size=1, it can be fixed by the following change (in megatron_gpt_model.py):

         self.megatron_timer_start('gradient_allreduce', log_level=1)
-        if self.use_fsdp:
+        if self.use_fsdp and self.cfg.get('fsdp_use_orig_params', False):
+            assert self.cfg.get('tensor_model_parallel_size', 1) == 1, 'fix only works for TP=1'
+        elif self.use_fsdp:
             # Reduce the gradients omitted from FSDP-sharding
             self.allreduce_fsdp_sharding_omitted_gradients()
         elif self.with_distributed_adam:

I still need to find a nice way to fix it with tensor_model_parallel_size > 1.

@janEbert (Contributor)

The fix for tensor_model_parallel_size >= 1 can be implemented very simply by accessing a private variable of the FSDP module.
Again in megatron_gpt_model.py; the patch from the previous comment should not be applied, so this is the proposed solution for all cases:

         grads = []
-        for param in self.model.parameters():
-            if not isinstance(param, torch.distributed.fsdp.FlatParameter) and param.requires_grad:
+        for param in self.model._ignored_params:
+            if param.requires_grad and param.grad is not None:
                 grad = param.grad
                 grads.append(grad.data)
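
For context, a hedged sketch of roughly what the surrounding method does with this change (helper and group names are illustrative, not the exact NeMo implementation): collect the gradients of the parameters FSDP was told to ignore, coalesce them, and all-reduce them across the data-parallel group.

import torch
from megatron.core import parallel_state

def allreduce_fsdp_sharding_omitted_gradients(fsdp_model):
    # Sketch only: reduce gradients of parameters that FSDP left unsharded,
    # using the private _ignored_params set referenced in the patch above.
    grads = [
        p.grad.data
        for p in fsdp_model._ignored_params
        if p.requires_grad and p.grad is not None
    ]
    if not grads:
        return
    coalesced = torch._utils._flatten_dense_tensors(grads)
    torch.distributed.all_reduce(
        coalesced, group=parallel_state.get_data_parallel_group()
    )
    for buf, synced in zip(
        grads, torch._utils._unflatten_dense_tensors(coalesced, grads)
    ):
        buf.copy_(synced)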
