Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Chat] Add KTO #5922

Merged
merged 13 commits into from
Jul 29, 2024
Merged
2 changes: 2 additions & 0 deletions .github/workflows/run_chatgpt_examples.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ jobs:
mkdir sft_data
mkdir prompt_data
mkdir preference_data
mkdir kto_data
./tests/test_data_preparation.sh
./tests/test_train.sh
env:
Expand All @@ -61,3 +62,4 @@ jobs:
SFT_DATASET: ./sft_data
PROMPT_DATASET: ./prompt_data
PREFERENCE_DATASET: ./preference_data
KTO_DATASET: ./kto_data
7 changes: 6 additions & 1 deletion applications/ColossalChat/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
- [Limitation for LLaMA-finetuned models](#limitation)
- [Limitation of dataset](#limitation)
- [Alternative Option For RLHF: DPO](#alternative-option-for-rlhf-direct-preference-optimization)
- [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization)
- [Alternative Option For RLHF: SimPO](#alternative-option-for-rlhf-simple-preference-optimization-simpo)
- [Alternative Option For RLHF: ORPO](#alternative-option-for-rlhf-odds-ratio-preference-optimization-orpo)
- [Alternative Option For RLHF: KTO](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
- [FAQ](#faq)
- [How to save/load checkpoint](#faq)
- [How to train with limited resources](#faq)
Expand Down Expand Up @@ -284,6 +286,9 @@ Simple Preference Optimization (SimPO) from this [paper](https://arxiv.org/pdf/2
## Alternative Option For RLHF: Odds Ratio Preference Optimization (ORPO)
Odds Ratio Preference Optimization (ORPO) from this [paper](https://arxiv.org/pdf/2403.07691) is a reference model free alignment method that use a mixture of SFT loss and a reinforcement leanring loss calculated based on odds-ratio-based implicit reward to makes the training more efficient and stable. Read this [README](./examples/README.md) for more information.

## Alternative Option For RLHF: Kahneman-Tversky Optimization (KTO)
We support the method introduced in the paper [KTO:Model Alignment as Prospect Theoretic Optimization](https://arxiv.org/pdf/2402.01306) (KTO). Which is a aligment method that directly maximize "human utility" of generation results. Read this [README](./examples/README.md) for more information.

### Inference Quantization and Serving - After Training

We provide an online inference server and a benchmark. We aim to run inference on single GPU, so quantization is essential when using large models.
Expand Down
21 changes: 12 additions & 9 deletions applications/ColossalChat/benchmarks/benchmark_dpo.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,33 @@ PROJECT_NAME="dpo"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
BENCHMARK_DATA_DIR="./temp/dpo" # Path to benchmark data
DATASET_SIZE=320

TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
declare -a dataset=(
$BENCHMARK_DATA_DIR/arrow/part-0
)

colossalai run --nproc_per_node 4 --master_port 31313 benchmark_dpo.py \
# Generate dummy test data
python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type preference


colossalai run --nproc_per_node 4 --master_port 31313 ../examples/training_scripts/train_dpo.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--config_file $CONFIG_FILE \
--dataset ${dataset[@]} \
--plugin "zero2_cpu" \
--max_epochs 1 \
--accumulation_steps 1 \
--batch_size 8 \
--batch_size 4 \
--lr 1e-6 \
--beta 0.1 \
--gamma 0.6 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--max_length 2048 \
--dataset_size 640 \
--weight_decay 0.01 \
--warmup_steps 60 \
--disable_reference_model \
--length_normalization \
--grad_checkpoint \
--use_flash_attn
51 changes: 51 additions & 0 deletions applications/ColossalChat/benchmarks/benchmark_kto.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#!/bin/bash
set_n_least_used_CUDA_VISIBLE_DEVICES() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv |
tail -n +2 |
nl -v 0 |
tee /dev/tty |
sort -g -k 2 |
awk '{print $1}' |
head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 4

PROJECT_NAME="kto"
PARENT_CONFIG_FILE="./benchmark_config" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
BENCHMARK_DATA_DIR="./temp/kto" # Path to benchmark data
DATASET_SIZE=80

TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
FULL_PROJECT_NAME="${PROJECT_NAME}-${TIMESTAMP}"
declare -a dataset=(
$BENCHMARK_DATA_DIR/arrow/part-0
)

# Generate dummy test data
python prepare_dummy_test_dataset.py --data_dir $BENCHMARK_DATA_DIR --dataset_size $DATASET_SIZE --max_length 2048 --data_type kto


colossalai run --nproc_per_node 2 --master_port 31313 ../examples/training_scripts/train_kto.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--dataset ${dataset[@]} \
--plugin "zero2_cpu" \
--max_epochs 1 \
--accumulation_steps 1 \
--batch_size 2 \
--lr 1e-5 \
--beta 0.1 \
--mixed_precision "bf16" \
--grad_clip 1.0 \
--max_length 2048 \
--weight_decay 0.01 \
--warmup_steps 60 \
--grad_checkpoint \
--use_flash_attn
Loading
Loading