
Commit

update ds config, update training scripts for seed and backward model sft
Spico197 committed Aug 15, 2023
1 parent 7a60f84 commit 61912db
Showing 9 changed files with 136 additions and 42 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -165,3 +165,4 @@ debug.json
data/seed/*.jsonl
data/unlabelled/*.jsonl
nohup.out
outputs
8 changes: 8 additions & 0 deletions README.md
@@ -33,6 +33,14 @@ $ python data/unlabelled/falcon_refinedweb.py

### Train Backward Model $M_{yx}$

- 8 * A100 40GB
- bf16
- gradient checkpointing
- deepspeed stage 2

```bash
$ bash scripts/train_backward_Myx.sh
```

### Self-Augmentation via $M_{yx}$

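For orientation, the bullets added to the README above correspond directly to options in the updated launcher later in this diff; a short recap, with flag names taken verbatim from `scripts/train_backward_Myx.sh` as changed in this commit:

```bash
# How the README bullets map onto launcher flags in scripts/train_backward_Myx.sh:
#   deepspeed stage 2      -> --deepspeed conf/ds_zero2default.json
#   bf16                   -> --bf16 True (together with --tf32 True)
#   gradient checkpointing -> --gradient_checkpointing
#   8 GPUs on one node     -> torchrun --nnodes 1 --nproc_per_node 8
bash scripts/train_backward_Myx.sh
```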
2 changes: 1 addition & 1 deletion conf/ds_zero1default.json → conf/ds_zero2default.json
@@ -3,7 +3,7 @@
"enabled": true
},
"zero_optimization": {
"stage": 1
"stage": 2
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
1 change: 1 addition & 0 deletions requirements.txt
@@ -2,3 +2,4 @@ pytorch-rex==0.1.9
datasets==2.14.1
tqdm==4.65.0
fschat[model_worker,train,webui]==0.2.24
deepspeed==0.10.0
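Assuming the usual pip workflow (not part of this commit), the pinned dependencies, including the newly added `deepspeed`, would be installed with:

```bash
# install the pinned dependencies, including the newly added deepspeed
pip install -r requirements.txt
```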
49 changes: 15 additions & 34 deletions scripts/train_backward_Myx.sh
@@ -1,39 +1,19 @@
#!/usr/bin/bash

#SBATCH --job-name=backward
#SBATCH --output=logs/%x-%j.log
#SBATCH --error=logs/%x-%j.log

#SBATCH --partition=Partition
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=32
#SBATCH --mem=256G
#SBATCH -x SH-IDCA1404-10-140-54-116

#SBATCH --nodes=1
#SBATCH --gres=gpu:4


source ~/anaconda3/bin/activate torch

num_nodes=1 # should match with --nodes
num_gpu_per_node=4 # should match with --gres
num_nodes=1
num_gpu_per_node=8

bsz=32
output_dir="outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID"
output_dir="outputs/backward"
bsz_per_dev=$(echo "${bsz} / ${num_nodes} / ${num_gpu_per_node}" | bc)

srun torchrun \
torchrun \
--nnodes ${num_nodes} \
--nproc_per_node ${num_gpu_per_node} \
--node_rank $SLURM_NODEID \
--rdzv_id $RANDOM \
--rdzv_backend c10d \
--rdzv_endpoint $head_node:29518 \
src/train_flash_attn.py \
-m src.train_flash_attn \
--reverse \
--deepspeed conf/ds_zero1default.json \
--model_name_or_path ~/Llama-2-7b-hf \
--deepspeed conf/ds_zero2default.json \
--model_name_or_path /home/zhutong/Llama-2-7b-hf \
--data_path data/seed/seed.jsonl \
--per_device_train_batch_size ${bsz_per_dev} \
--per_device_eval_batch_size ${bsz_per_dev} \
@@ -44,18 +24,19 @@ srun torchrun \
--final_lr "9e-6" \
--weight_decay 0.1 \
--max_grad_norm 1.0 \
--evaluation_strategy "no" \
--logging_strategy steps \
--logging_steps 10 \
--save_strategy steps \
--save_total_limit 2 \
--save_steps 100 \
--logging_steps 1 \
--save_strategy epoch \
--save_total_limit 3 \
--output_dir ${output_dir} \
--overwrite_output_dir \
--ddp_timeout 30000 \
--logging_first_step True \
--bf16 \
--torch_dtype bfloat16 \
--bf16 True \
--tf32 True \
--ddp_find_unused_parameters False \
--gradient_checkpointing \
--report_to none \
--log_level info
--log_level info \
--lazy_preprocess True
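With the new defaults (`bsz=32`, one node, eight GPUs), the `bc` expression works out to 4 sequences per device, and the rewritten script is launched directly rather than through Slurm, using the command already shown in the README diff above:

```bash
# per-device batch size, as computed in the script (integer division in bc):
echo "32 / 1 / 8" | bc    # -> 4

# launch directly on a single 8-GPU node (no sbatch/srun required):
bash scripts/train_backward_Myx.sh
```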
62 changes: 62 additions & 0 deletions scripts/train_backward_Myx_slurm.sh
@@ -0,0 +1,62 @@
#!/usr/bin/bash

#SBATCH --job-name=backward
#SBATCH --output=logs/%x-%j.log
#SBATCH --error=logs/%x-%j.log

#SBATCH --partition=Partition
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=32
#SBATCH --mem=256G
#SBATCH -x SH-IDCA1404-10-140-54-116

#SBATCH --nodes=1
#SBATCH --gres=gpu:8


source ~/anaconda3/bin/activate torch

num_nodes=1 # should match with --nodes
num_gpu_per_node=8 # should match with --gres

bsz=32
output_dir="outputs/$SLURM_JOB_NAME-$SLURM_JOB_ID"
bsz_per_dev=$(echo "${bsz} / ${num_nodes} / ${num_gpu_per_node}" | bc)

srun torchrun \
--nnodes ${num_nodes} \
--nproc_per_node ${num_gpu_per_node} \
--node_rank $SLURM_NODEID \
--rdzv_id $RANDOM \
--rdzv_backend c10d \
--rdzv_endpoint $head_node:29518 \
src/train_flash_attn.py \
--reverse \
--deepspeed conf/ds_zero2default.json \
--model_name_or_path /home/zhutong/Llama-2-7b-hf \
--data_path data/seed/seed.jsonl \
--per_device_train_batch_size ${bsz_per_dev} \
--per_device_eval_batch_size ${bsz_per_dev} \
--num_train_epochs 15 \
--adam_beta1 0.9 \
--adam_beta2 0.95 \
--learning_rate "1e-5" \
--final_lr "9e-6" \
--weight_decay 0.1 \
--max_grad_norm 1.0 \
--evaluation_strategy "no" \
--logging_strategy steps \
--logging_steps 1 \
--save_strategy epoch \
--save_total_limit 3 \
--output_dir ${output_dir} \
--overwrite_output_dir \
--ddp_timeout 30000 \
--logging_first_step True \
--bf16 True \
--tf32 True \
--ddp_find_unused_parameters False \
--gradient_checkpointing \
--report_to none \
--log_level info \
--lazy_preprocess True
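The Slurm variant keeps the original `srun torchrun` rendezvous (c10d on `$head_node:29518`) and the `$SLURM_JOB_NAME-$SLURM_JOB_ID` output directory. Presumably it is submitted with `sbatch`; the submission command itself is not part of this commit:

```bash
# submit as a batch job; the #SBATCH headers above request one node with 8 GPUs
sbatch scripts/train_backward_Myx_slurm.sh
```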
41 changes: 41 additions & 0 deletions scripts/train_seed.sh
@@ -0,0 +1,41 @@
#!/usr/bin/bash

num_nodes=1
num_gpu_per_node=8

bsz=32
output_dir="outputs/seed_model"
bsz_per_dev=$(echo "${bsz} / ${num_nodes} / ${num_gpu_per_node}" | bc)

torchrun \
--nnodes ${num_nodes} \
--nproc_per_node ${num_gpu_per_node} \
-m src.train_flash_attn \
--deepspeed conf/ds_zero2default.json \
--model_name_or_path /home/zhutong/Llama-2-7b-hf \
--data_path data/seed/seed.jsonl \
--per_device_train_batch_size ${bsz_per_dev} \
--per_device_eval_batch_size ${bsz_per_dev} \
--num_train_epochs 15 \
--adam_beta1 0.9 \
--adam_beta2 0.95 \
--learning_rate "1e-5" \
--final_lr "9e-6" \
--weight_decay 0.1 \
--max_grad_norm 1.0 \
--evaluation_strategy "no" \
--logging_strategy steps \
--logging_steps 1 \
--save_strategy epoch \
--save_total_limit 3 \
--output_dir ${output_dir} \
--overwrite_output_dir \
--ddp_timeout 30000 \
--logging_first_step True \
--bf16 True \
--tf32 True \
--ddp_find_unused_parameters False \
--gradient_checkpointing \
--report_to none \
--log_level info \
--lazy_preprocess True
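`scripts/train_seed.sh` mirrors the backward script but omits `--reverse`, so it fine-tunes the seed model in the forward direction and writes checkpoints to `outputs/seed_model`, which the new `.gitignore` entry covers. Presumably it is launched the same way as the backward script; the README in this commit does not add a command for it:

```bash
# forward (seed) SFT on one 8-GPU node; checkpoints are written to outputs/seed_model
bash scripts/train_seed.sh
```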
12 changes: 6 additions & 6 deletions src/train.py
@@ -5,7 +5,6 @@
https://github.com/lm-sys/FastChat/blob/main/LICENSE
"""

import json
import math
import pathlib
from dataclasses import dataclass, field
@@ -14,6 +13,7 @@

import torch
import transformers
from rex.utils.io import load_jsonlines
from fastchat.conversation import SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template
from torch.optim.lr_scheduler import LambdaLR
@@ -145,7 +145,7 @@ def preprocess(
sources,
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
conv = get_conversation_template("llama-2")
conv = get_conversation_template("vicuna_v1.1")
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

# Apply prompt templates
@@ -314,15 +314,15 @@ def make_supervised_data_module(
)
rank0_print("Loading data...")

train_json = json.load(open(data_args.data_path, "r"))
train_data = load_jsonlines(data_args.data_path)
train_dataset = dataset_cls(
train_json, tokenizer=tokenizer, reverse=data_args.reverse
train_data, tokenizer=tokenizer, reverse=data_args.reverse
)

if data_args.eval_data_path:
eval_json = json.load(open(data_args.eval_data_path, "r"))
eval_data = load_jsonlines(data_args.eval_data_path)
eval_dataset = dataset_cls(
eval_json, tokenizer=tokenizer, reverse=data_args.reverse
eval_data, tokenizer=tokenizer, reverse=data_args.reverse
)
else:
eval_dataset = None
2 changes: 1 addition & 1 deletion src/train_flash_attn.py
@@ -4,7 +4,7 @@

replace_llama_attn_with_flash_attn()

from .train import train
from src.train import train

if __name__ == "__main__":
train()
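The import switch from the relative `.train` to the absolute `src.train` lines up with the launcher change earlier in this diff: the entry point is now invoked as a module rather than as a file path (training flags omitted here for brevity):

```bash
# before this commit: the entry point was passed as a file path
torchrun --nnodes 1 --nproc_per_node 8 src/train_flash_attn.py
# after this commit: the entry point is passed as a module, run from the repository root
torchrun --nnodes 1 --nproc_per_node 8 -m src.train_flash_attn
```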
