diff --git a/.gitignore b/.gitignore
index 0bf1012..dbcb7ba 100644
--- a/.gitignore
+++ b/.gitignore
@@ -165,3 +165,4 @@ debug.json
 data/seed/*.jsonl
 data/unlabelled/*.jsonl
 nohup.out
+outputs
diff --git a/README.md b/README.md
index bf39aeb..9c40456 100644
--- a/README.md
+++ b/README.md
@@ -33,6 +33,14 @@ $ python data/unlabelled/falcon_refinedweb.py
 
 ### Train Backward Model $M_{yx}$
 
+- 8 * A100 40GB
+- bf16
+- gradient checkpointing
+- deepspeed stage 2
+
+```bash
+$ bash scripts/train_backward_Myx.sh
+```
 
 ### Self-Augmentation via $M_{yx}$
 
diff --git a/conf/ds_zero1default.json b/conf/ds_zero2default.json
similarity index 94%
rename from conf/ds_zero1default.json
rename to conf/ds_zero2default.json
index 10f23ac..5b96617 100644
--- a/conf/ds_zero1default.json
+++ b/conf/ds_zero2default.json
@@ -3,7 +3,7 @@
         "enabled": true
     },
     "zero_optimization": {
-        "stage": 1
+        "stage": 2
     },
     "gradient_accumulation_steps": "auto",
     "gradient_clipping": "auto",
diff --git a/requirements.txt b/requirements.txt
index 611cdf6..f0939d2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,3 +2,4 @@ pytorch-rex==0.1.9
 datasets==2.14.1
 tqdm==4.65.0
 fschat[model_worker,train,webui]==0.2.24
+deepspeed==0.10.0
\ No newline at end of file
diff --git a/scripts/train_backward_Myx.sh b/scripts/train_backward_Myx.sh
index 211ea24..c579678 100644
--- a/scripts/train_backward_Myx.sh
+++ b/scripts/train_backward_Myx.sh
@@ -1,39 +1,19 @@
 #!/usr/bin/bash
 
-#SBATCH --job-name=backward
-#SBATCH --output=logs/%x-%j.log
-#SBATCH --error=logs/%x-%j.log
-
-#SBATCH --partition=Partition
-#SBATCH --ntasks-per-node=1
-#SBATCH --cpus-per-task=32
-#SBATCH --mem=256G
-#SBATCH -x SH-IDCA1404-10-140-54-116
-
-#SBATCH --nodes=1
-#SBATCH --gres=gpu:4
-
-
-source ~/anaconda3/bin/activate torch
-
-num_nodes=1 # should match with --nodes
-num_gpu_per_node=4 # should match with --gres
+num_nodes=1
+num_gpu_per_node=8
 
 bsz=32
-output_dir="outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID"
+output_dir="outputs/backward"
 bsz_per_dev=$(echo "${bsz} / ${num_nodes} / ${num_gpu_per_node}" | bc)
 
-srun torchrun \
+torchrun \
     --nnodes ${num_nodes} \
     --nproc_per_node ${num_gpu_per_node} \
-    --node_rank $SLURM_NODEID \
-    --rdzv_id $RANDOM \
-    --rdzv_backend c10d \
-    --rdzv_endpoint $head_node:29518 \
-    src/train_flash_attn.py \
+    -m src.train_flash_attn \
     --reverse \
-    --deepspeed conf/ds_zero1default.json \
-    --model_name_or_path ~/Llama-2-7b-hf \
+    --deepspeed conf/ds_zero2default.json \
+    --model_name_or_path /home/zhutong/Llama-2-7b-hf \
     --data_path data/seed/seed.jsonl \
     --per_device_train_batch_size ${bsz_per_dev} \
     --per_device_eval_batch_size ${bsz_per_dev} \
@@ -44,18 +24,19 @@ srun torchrun \
     --final_lr "9e-6" \
     --weight_decay 0.1 \
     --max_grad_norm 1.0 \
+    --evaluation_strategy "no" \
     --logging_strategy steps \
-    --logging_steps 10 \
-    --save_strategy steps \
-    --save_total_limit 2 \
-    --save_steps 100 \
+    --logging_steps 1 \
+    --save_strategy epoch \
+    --save_total_limit 3 \
     --output_dir ${output_dir} \
     --overwrite_output_dir \
     --ddp_timeout 30000 \
     --logging_first_step True \
-    --bf16 \
-    --torch_dtype bfloat16 \
+    --bf16 True \
+    --tf32 True \
     --ddp_find_unused_parameters False \
     --gradient_checkpointing \
     --report_to none \
-    --log_level info
+    --log_level info \
+    --lazy_preprocess True
diff --git a/scripts/train_backward_Myx_slurm.sh b/scripts/train_backward_Myx_slurm.sh
new file mode 100644
index 0000000..d8f50c6
--- /dev/null
+++ b/scripts/train_backward_Myx_slurm.sh
@@ -0,0 +1,62 @@
+#!/usr/bin/bash
+
+#SBATCH --job-name=backward
+#SBATCH --output=logs/%x-%j.log
+#SBATCH --error=logs/%x-%j.log
+
+#SBATCH --partition=Partition
+#SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=32
+#SBATCH --mem=256G
+#SBATCH -x SH-IDCA1404-10-140-54-116
+
+#SBATCH --nodes=1
+#SBATCH --gres=gpu:8
+
+
+source ~/anaconda3/bin/activate torch
+
+num_nodes=1 # should match with --nodes
+num_gpu_per_node=8 # should match with --gres
+
+bsz=32
+output_dir="outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID"
+bsz_per_dev=$(echo "${bsz} / ${num_nodes} / ${num_gpu_per_node}" | bc)
+
+srun torchrun \
+    --nnodes ${num_nodes} \
+    --nproc_per_node ${num_gpu_per_node} \
+    --node_rank $SLURM_NODEID \
+    --rdzv_id $RANDOM \
+    --rdzv_backend c10d \
+    --rdzv_endpoint $head_node:29518 \
+    src/train_flash_attn.py \
+    --reverse \
+    --deepspeed conf/ds_zero2default.json \
+    --model_name_or_path /home/zhutong/Llama-2-7b-hf \
+    --data_path data/seed/seed.jsonl \
+    --per_device_train_batch_size ${bsz_per_dev} \
+    --per_device_eval_batch_size ${bsz_per_dev} \
+    --num_train_epochs 15 \
+    --adam_beta1 0.9 \
+    --adam_beta2 0.95 \
+    --learning_rate "1e-5" \
+    --final_lr "9e-6" \
+    --weight_decay 0.1 \
+    --max_grad_norm 1.0 \
+    --evaluation_strategy "no" \
+    --logging_strategy steps \
+    --logging_steps 1 \
+    --save_strategy epoch \
+    --save_total_limit 3 \
+    --output_dir ${output_dir} \
+    --overwrite_output_dir \
+    --ddp_timeout 30000 \
+    --logging_first_step True \
+    --bf16 True \
+    --tf32 True \
+    --ddp_find_unused_parameters False \
+    --gradient_checkpointing \
+    --report_to none \
+    --log_level info \
+    --lazy_preprocess True
diff --git a/scripts/train_seed.sh b/scripts/train_seed.sh
index e69de29..387149f 100644
--- a/scripts/train_seed.sh
+++ b/scripts/train_seed.sh
@@ -0,0 +1,41 @@
+#!/usr/bin/bash
+
+num_nodes=1
+num_gpu_per_node=8
+
+bsz=32
+output_dir="outputs/seed_model"
+bsz_per_dev=$(echo "${bsz} / ${num_nodes} / ${num_gpu_per_node}" | bc)
+
+torchrun \
+    --nnodes ${num_nodes} \
+    --nproc_per_node ${num_gpu_per_node} \
+    -m src.train_flash_attn \
+    --deepspeed conf/ds_zero2default.json \
+    --model_name_or_path /home/zhutong/Llama-2-7b-hf \
+    --data_path data/seed/seed.jsonl \
+    --per_device_train_batch_size ${bsz_per_dev} \
+    --per_device_eval_batch_size ${bsz_per_dev} \
+    --num_train_epochs 15 \
+    --adam_beta1 0.9 \
+    --adam_beta2 0.95 \
+    --learning_rate "1e-5" \
+    --final_lr "9e-6" \
+    --weight_decay 0.1 \
+    --max_grad_norm 1.0 \
+    --evaluation_strategy "no" \
+    --logging_strategy steps \
+    --logging_steps 1 \
+    --save_strategy epoch \
+    --save_total_limit 3 \
+    --output_dir ${output_dir} \
+    --overwrite_output_dir \
+    --ddp_timeout 30000 \
+    --logging_first_step True \
+    --bf16 True \
+    --tf32 True \
+    --ddp_find_unused_parameters False \
+    --gradient_checkpointing \
+    --report_to none \
+    --log_level info \
+    --lazy_preprocess True
diff --git a/src/train.py b/src/train.py
index 2789c3d..c33bf5f 100644
--- a/src/train.py
+++ b/src/train.py
@@ -5,7 +5,6 @@
 https://github.com/lm-sys/FastChat/blob/main/LICENSE
 """
 
-import json
 import math
 import pathlib
 from dataclasses import dataclass, field
@@ -14,6 +13,7 @@
 import torch
 import transformers
+from rex.utils.io import load_jsonlines
 from fastchat.conversation import SeparatorStyle
 from fastchat.model.model_adapter import get_conversation_template
 from torch.optim.lr_scheduler import LambdaLR
@@ -145,7 +145,7 @@ def preprocess(
     sources,
     tokenizer: transformers.PreTrainedTokenizer,
 ) -> Dict:
-    conv = get_conversation_template("llama-2")
+    conv = get_conversation_template("vicuna_v1.1")
     roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
 
     # Apply prompt templates
@@ -314,15 +314,15 @@ def make_supervised_data_module(
     )
 
     rank0_print("Loading data...")
-    train_json = json.load(open(data_args.data_path, "r"))
+    train_data = load_jsonlines(data_args.data_path)
     train_dataset = dataset_cls(
-        train_json, tokenizer=tokenizer, reverse=data_args.reverse
+        train_data, tokenizer=tokenizer, reverse=data_args.reverse
     )
     if data_args.eval_data_path:
-        eval_json = json.load(open(data_args.eval_data_path, "r"))
+        eval_data = load_jsonlines(data_args.eval_data_path)
         eval_dataset = dataset_cls(
-            eval_json, tokenizer=tokenizer, reverse=data_args.reverse
+            eval_data, tokenizer=tokenizer, reverse=data_args.reverse
         )
     else:
         eval_dataset = None
diff --git a/src/train_flash_attn.py b/src/train_flash_attn.py
index ca93aec..947a987 100644
--- a/src/train_flash_attn.py
+++ b/src/train_flash_attn.py
@@ -4,7 +4,7 @@
 
 replace_llama_attn_with_flash_attn()
 
-from .train import train
+from src.train import train
 
 if __name__ == "__main__":
     train()