Merge pull request #5922 from hpcaitech/kto

[Chat] Add KTO
YeAnbang authored Jul 29, 2024
2 parents ad35a98 + 6fd9e86 commit c8332b9
Showing 49 changed files with 1,549 additions and 966 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/run_chatgpt_examples.yml
@@ -52,6 +52,7 @@ jobs:
mkdir sft_data
mkdir prompt_data
mkdir preference_data
mkdir kto_data
./tests/test_data_preparation.sh
./tests/test_train.sh
env:
@@ -61,3 +62,4 @@
SFT_DATASET: ./sft_data
PROMPT_DATASET: ./prompt_data
PREFERENCE_DATASET: ./preference_data
KTO_DATASET: ./kto_data
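
Unlike the preference dataset, which pairs a chosen and a rejected response per prompt, a KTO dataset attaches a binary desirability label to each individual completion. A purely illustrative sample is shown below; the field names are assumptions and may not match this repository's exact schema:

```python
# Hypothetical KTO sample, for illustration only; the real schema used by
# ColossalChat's data pipeline may differ.
kto_sample = {
    "prompt": [{"from": "user", "content": "What is the capital of France?"}],
    "completion": {"from": "assistant", "content": "The capital of France is Paris."},
    "label": True,  # True = desirable completion, False = undesirable
}
```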
7 changes: 6 additions & 1 deletion applications/ColossalChat/README.md
@@ -24,7 +24,9 @@
- [Limitation for LLaMA-finetuned models](#limitation)
- [Limitation of dataset](#limitation)
- [Alternative Option For RLHF: DPO](#alternative-option-for-rlhf-direct-preference-optimization)
- [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization)
- [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization-simpo)
- [Alternative Option For RLHF: ORPO](#alternative-option-for-rlhf-odds-ratio-preference-optimization-orpo)
- [Alternative Option For RLHF: KTO](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
- [FAQ](#faq)
- [How to save/load checkpoint](#faq)
- [How to train with limited resources](#faq)
@@ -284,6 +286,9 @@ Simple Preference Optimization (SimPO) from this [paper](https://arxiv.org/pdf/2
## Alternative Option For RLHF: Odds Ratio Preference Optimization (ORPO)
Odds Ratio Preference Optimization (ORPO) from this [paper](https://arxiv.org/pdf/2403.07691) is a reference-model-free alignment method that uses a mixture of the SFT loss and a reinforcement learning loss calculated from an odds-ratio-based implicit reward, which makes training more efficient and stable. Read this [README](./examples/README.md) for more information.
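
For intuition, the odds-ratio term can be written as the following minimal sketch (illustrative only, not the repository's implementation; `logp_chosen` and `logp_rejected` are assumed to be length-normalized per-sequence log-probabilities):

```python
import torch
import torch.nn.functional as F

def orpo_loss_sketch(logp_chosen, logp_rejected, sft_loss, lam=0.1):
    # log(odds(y|x)) = log p - log(1 - p); log1p(-exp(x)) is the numerically
    # stable form of log(1 - e^x) for log-probabilities x < 0
    log_odds_chosen = logp_chosen - torch.log1p(-torch.exp(logp_chosen))
    log_odds_rejected = logp_rejected - torch.log1p(-torch.exp(logp_rejected))
    # reward the model when the chosen response has higher odds than the rejected one
    l_or = -F.logsigmoid(log_odds_chosen - log_odds_rejected).mean()
    # final objective: SFT loss plus the weighted odds-ratio penalty
    return sft_loss + lam * l_or
```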

## Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
We support the method introduced in the paper [KTO: Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO), an alignment method that directly maximizes the "human utility" of generation results. Read this [README](./examples/README.md) for more information.
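
For intuition only, the loss can be sketched as below (a simplified sketch assuming precomputed per-sequence log-probabilities and a batch-level KL estimate; names are illustrative, not the repository's API):

```python
import torch

def kto_loss_sketch(policy_logp, ref_logp, desirable, kl, beta=0.1,
                    lambda_d=1.0, lambda_u=1.0):
    # implicit reward: log-ratio of the policy to the frozen reference model
    rewards = beta * (policy_logp - ref_logp)
    # reference point from a (clamped) batch estimate of the policy/reference KL
    z = beta * torch.clamp(kl, min=0)
    losses = torch.where(
        desirable,  # boolean tensor: True marks a desirable completion
        lambda_d * (1 - torch.sigmoid(rewards - z)),
        lambda_u * (1 - torch.sigmoid(z - rewards)),
    )
    return losses.mean()
```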

### Inference Quantization and Serving - After Training

We provide an online inference server and a benchmark. We aim to run inference on a single GPU, so quantization is essential when using large models.
21 changes: 12 additions & 9 deletions applications/ColossalChat/benchmarks/benchmark_dpo.sh
@@ -19,30 +19,33 @@ PROJECT_NAME="dpo"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
BENCHMARK_DATA_DIR="./temp/dpo" # Path to benchmark data
DATASET_SIZE=320

TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
declare -a dataset=(
$BENCHMARK_DATA_DIR/arrow/part-0
)

colossalai run --nproc_per_node 4 --master_port 31313 benchmark_dpo.py \
# Generate dummy test data
python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type preference


colossalai run --nproc_per_node 4 --master_port 31313 ../examples/training_scripts/train_dpo.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--config_file $CONFIG_FILE \
--dataset ${dataset[@]} \
--plugin "zero2_cpu" \
--max_epochs 1 \
--accumulation_steps 1 \
--batch_size 8 \
--batch_size 4 \
--lr 1e-6 \
--beta 0.1 \
--gamma 0.6 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--max_length 2048 \
--dataset_size 640 \
--weight_decay 0.01 \
--warmup_steps 60 \
--disable_reference_model \
--length_normalization \
--grad_checkpoint \
--use_flash_attn
51 changes: 51 additions & 0 deletions applications/ColossalChat/benchmarks/benchmark_kto.sh
@@ -0,0 +1,51 @@
#!/bin/bash
set_n_least_used_CUDA_VISIBLE_DEVICES() {
    local n=${1:-"9999"}
    echo "GPU Memory Usage:"
    # Rank GPUs by current memory usage and keep the indices of the n least-used ones.
    local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
        tail -n +2 |       # drop the CSV header row
        nl -v 0 |          # prepend a 0-based GPU index
        tee /dev/tty |     # also print the indexed table to the terminal
        sort -g -k 2 |     # sort by used memory, ascending
        awk '{print $1}' | # keep only the GPU index column
        head -n $n)
    export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
    echo "Now CUDA_VISIBLE_DEVICES is set to:"
    echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 4

PROJECT_NAME="kto"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
BENCHMARK_DATA_DIR="./temp/kto" # Path to benchmark data
DATASET_SIZE=80

TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
declare -a dataset=(
$BENCHMARK_DATA_DIR/arrow/part-0
)

# Generate dummy test data
python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type kto


colossalai run --nproc_per_node 2 --master_port 31313 ../examples/training_scripts/train_kto.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--plugin "zero2_cpu" \
--max_epochs 1 \
--accumulation_steps 1 \
--batch_size 2 \
--lr 1e-5 \
--beta 0.1 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--max_length 2048 \
--weight_decay 0.01 \
--warmup_steps 60 \
--grad_checkpoint \
--use_flash_attn