
YiyangZhou/CSR

Calibrated Self-Rewarding Vision Language Models

Yiyang Zhou*, Zhiyuan Fan*, Dongjie Cheng*, Sihan Yang, Zhaorun Chen, Chenhang Cui, Xiyao Wang, Yun Li, Linjun Zhang, Huaxiu Yao

Hugging Face

[Project page]

Citation: If you find this repo useful for your research, please consider citing the paper:

@article{zhou2024calibrated,
  title={Calibrated Self-Rewarding Vision Language Models},
  author={Zhou, Yiyang and Fan, Zhiyuan and Cheng, Dongjie and Yang, Sihan and Chen, Zhaorun and Cui, Chenhang and Wang, Xiyao and Li, Yun and Zhang, Linjun and Yao, Huaxiu},
  journal={arXiv preprint arXiv:2405.14622},
  year={2024}
}

About CSR


Framework of Calibrated Self-Rewarding (CSR)

Existing methods use additional models or human annotations to curate preference data and enhance modality alignment through preference optimization. These methods are resource-intensive, and the curated data may not reflect the target large vision-language model's (LVLM's) own preferences, which makes the curated preference pairs easy for the target model to distinguish. To address these challenges, we propose Calibrated Self-Rewarding (CSR), which enables the model to self-improve by iteratively generating candidate responses, evaluating the reward for each response, and curating preference data for fine-tuning. For reward modeling, CSR adopts a step-wise strategy and incorporates visual constraints into the self-rewarding process to place more emphasis on the visual input.
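The reward-and-curation idea above can be illustrated in a few lines of Python. This is a minimal sketch under stated assumptions, not the repository's implementation: it assumes each candidate response already carries two precomputed scores in [0, 1], a self-reward derived from the model's own token probabilities (`text_score`) and an image-response relevance score standing in for the visual constraint (`image_score`), and it assumes a simple linear weighting `lam` between them. The names `Candidate`, `calibrated_reward`, and `curate_pair` are hypothetical, and the step-wise (sentence-level) part of CSR's reward modeling is left out.

```python
from dataclasses import dataclass


@dataclass
class Candidate:
    """A candidate response with two precomputed scores (hypothetical fields)."""
    text: str
    text_score: float   # assumed self-reward in [0, 1], e.g. mean token probability
    image_score: float  # assumed image-response relevance in [0, 1]


def calibrated_reward(c: Candidate, lam: float = 0.9) -> float:
    """Linear mix of the self-reward and the visual-constraint score (assumed form)."""
    return lam * c.text_score + (1.0 - lam) * c.image_score


def curate_pair(candidates: list[Candidate], lam: float = 0.9) -> tuple[str, str]:
    """Pick the highest-reward response as 'chosen' and the lowest as 'rejected'."""
    if len(candidates) < 2:
        raise ValueError("need at least two candidates to form a preference pair")
    ranked = sorted(candidates, key=lambda c: calibrated_reward(c, lam))
    return ranked[-1].text, ranked[0].text


if __name__ == "__main__":
    cands = [
        Candidate("A dog lying on the grass.", text_score=0.80, image_score=0.90),
        Candidate("A cat sleeping on a sofa.", text_score=0.85, image_score=0.10),
    ]
    print(curate_pair(cands))           # visual term favors the grounded answer
    print(curate_pair(cands, lam=1.0))  # text-only reward favors the confident one
```

With `lam=1.0` the image term drops out and the more confident but poorly grounded response becomes "chosen", which illustrates why a visual constraint is added to a purely self-generated reward.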


Left: LLaVA 1.5 models of different parameter sizes keep improving over successive CSR iterations. Right: Change in image relevance scores before and after applying CSR.

Over successive online CSR iterations, the model's performance improves across various benchmarks, and its responses become more relevant to the visual input overall. CSR also narrows the gap between rejected and chosen responses, which raises the lower bound of the model's performance.

Installation

The build process is based on LLaVA 1.5:

  1. Clone the LLaVA repository, navigate to the LLaVA folder, and clone this repository inside it
git clone https://github.com/haotian-liu/LLaVA.git
cd LLaVA
git clone https://github.com/YiyangZhou/CSR.git
  2. Install the package
conda create -n csr python=3.10 -y
conda activate csr
pip install --upgrade pip
pip install -e .
  3. Install additional packages for training
pip install -e ".[train]"
pip install flash-attn --no-build-isolation
  4. Install the trl package
pip install trl
  5. Modify the TRL library to adapt DPO for LVLMs
cd *your conda path*/envs/csr/lib/python3.10/site-packages/trl/trainer/
# Replace dpo_trainer.py here with the dpo_trainer.py from the 'train_csr' folder of this repository.
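Instead of locating the conda path by hand, the replacement can be scripted. The sketch below is a hypothetical helper, not part of the repository: it asks Python where the installed `trl` package lives (without importing it), keeps a `.bak` copy of TRL's original `trainer/dpo_trainer.py`, and copies the modified file over it. The function names `package_dir` and `replace_with_backup` are illustrative only.

```python
import importlib.util
import shutil
from pathlib import Path


def package_dir(name: str) -> Path:
    """Return the install directory of a regular package without importing it."""
    spec = importlib.util.find_spec(name)
    if spec is None or spec.origin is None:
        raise ModuleNotFoundError(f"package {name!r} is not installed")
    return Path(spec.origin).parent  # origin points at the package's __init__.py


def replace_with_backup(src: Path, dst: Path) -> Path:
    """Copy src over dst, saving the original dst as '<name>.bak' the first time."""
    backup = dst.with_name(dst.name + ".bak")
    if dst.exists() and not backup.exists():
        shutil.copy2(dst, backup)  # keep TRL's original file for rollback
    shutil.copy2(src, dst)
    return backup
```

Run from the LLaVA folder, a call such as `replace_with_backup(Path("CSR/train_csr/dpo_trainer.py"), package_dir("trl") / "trainer" / "dpo_trainer.py")` would perform the swap; that source path assumes CSR was cloned inside the LLaVA folder as in step 1, so adjust it to your layout. The backup is only written once, so rerunning the helper never overwrites TRL's original file.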
  6. Change the parent class of LLaVATrainer in llava_trainer.py
cd ./LLaVA/llava/train

# In llava_trainer.py, import DPOTrainer from trl and make it the parent class of LLaVATrainer:

# from trl import DPOTrainer
# ...
# class LLaVATrainer(DPOTrainer):  # originally: class LLaVATrainer(Trainer):

Instruction

Before starting, you need to:

(1) Modify the paths in './CSR/scripts/run_train.sh' to point to your own directories.

(2) If you use wandb, enter your API key in './CSR/train_csr/train_dpo_lora.py' by replacing "your key" in 'wandb.login(key="your key")'.

(3) Download the image data from the COCO website into './data/images/' (or prepare your own images and prompt data).
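Before running Step 1, a quick sanity check can confirm that the images are in place. The sketch below is a hypothetical helper, not part of the repository: it assumes the standard COCO-2014 file naming (`COCO_train2014_` followed by a 12-digit zero-padded image id and `.jpg`) and simply collects the ids it finds, ignoring any other files.

```python
import re
from pathlib import Path

# Assumed naming scheme of COCO-2014 train images, e.g. COCO_train2014_000000000009.jpg
COCO_TRAIN2014 = re.compile(r"^COCO_train2014_(\d{12})\.jpg$")


def coco_train_ids(image_dir) -> list[int]:
    """Return the sorted COCO image ids of all train2014 files found in image_dir."""
    ids = []
    for path in Path(image_dir).iterdir():
        match = COCO_TRAIN2014.match(path.name)
        if match:
            ids.append(int(match.group(1)))
    return sorted(ids)
```

For example, `len(coco_train_ids("./data/images/"))` should report 82,783 if the full train2014 split was downloaded; a smaller number points to a partial download (or to a custom image set, which this check does not cover).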

Step 1. Construct Preference Data.

First, place the COCO-2014 train images in './data/images/'. Then run the following steps in sequence:

cd ./CSR/inference_csr
bash ./step1.sh
bash ./step2.sh
bash ./step3.sh

You now have the preference dataset. Because this process takes a long time, we also provide our preference datasets on Hugging Face.

Step 2. Direct Preference Optimization (DPO).

bash ./CSR/scripts/run_train.sh
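Step 2 trains on the curated pairs with DPO through the modified TRL trainer. For reference, the standard (sigmoid) DPO loss for a single preference pair is sketched below in plain Python. This is an illustration of the textbook objective, not the repository's code: the modified `dpo_trainer.py` may differ in details, the function names are hypothetical, each argument is assumed to be a response's summed log-probability, and `beta=0.1` is an assumed default rather than the value set in `run_train.sh`.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def dpo_loss(policy_chosen_logp: float, policy_rejected_logp: float,
             ref_chosen_logp: float, ref_rejected_logp: float,
             beta: float = 0.1) -> float:
    """Standard sigmoid DPO loss for one (chosen, rejected) pair.

    Each argument is the summed log-probability of a full response under the
    trainable policy or the frozen reference model.
    loss = -log sigmoid(beta * [(pc - rc) - (pr - rr)]) = softplus(-margin)
    """
    margin = beta * ((policy_chosen_logp - ref_chosen_logp)
                     - (policy_rejected_logp - ref_rejected_logp))
    return softplus(-margin)
```

When the policy and the reference agree, the margin is 0 and the loss is log 2; the loss falls as the policy raises the chosen response's likelihood, relative to the reference, more than the rejected one's. Using `softplus` instead of `math.log(1 / (1 + math.exp(-m)))` avoids overflow for large negative margins.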

Step 3. Iterative Learning.

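After each round of CSR training, merge the current LoRA checkpoint into its base model. Then use the merged checkpoint as the new base model and repeat Step 1 and Step 2. The '--model-base' argument is the base model of the current round: your LLaVA 1.5 checkpoint for the first round, your merged Iter-1 checkpoint for the second round, your merged Iter-2 checkpoint for the third, and so on.

python ./scripts/merge_lora_weights.py --model-path "your LoRA checkpoint path" --model-base "base model of the current round" --save-model-path "your merged checkpoint path"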
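The chaining of base models across rounds can be made explicit with a small helper that builds one merge command per iteration. This is a hypothetical sketch: the function name and all paths are placeholders, and only the script path and its '--model-path', '--model-base', and '--save-model-path' flags come from the command above.

```python
def merge_commands(llava_base: str, lora_ckpts: list[str],
                   merged_dirs: list[str]) -> list[list[str]]:
    """Build one merge command per CSR iteration.

    Iteration k merges its LoRA checkpoint into the base model of that round:
    the original LLaVA 1.5 checkpoint for k = 1, then the merged model of k - 1.
    """
    if len(lora_ckpts) != len(merged_dirs):
        raise ValueError("need one output directory per LoRA checkpoint")
    commands = []
    base = llava_base
    for lora, out in zip(lora_ckpts, merged_dirs):
        commands.append([
            "python", "./scripts/merge_lora_weights.py",
            "--model-path", lora,
            "--model-base", base,
            "--save-model-path", out,
        ])
        base = out  # the merged model becomes the next round's base
    return commands
```

Each command can be passed to `subprocess.run(cmd, check=True)`, but only after that round's Step 1 and Step 2 have produced its LoRA checkpoint, since round k trains on top of the model merged in round k - 1.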

Data and Models

We provide CSR training data and model weights on HuggingFace. Please refer to the Instruction for usage.

| Dataset | Download | Model (7B) | Download | Model (13B) | Download |
|---|---|---|---|---|---|
| CSR-iter0 | 🤗 HuggingFace | CSR-7B-iter1 | 🤗 HuggingFace | CSR-13B-iter1 | 🤗 HuggingFace |
| CSR-iter1 | 🤗 HuggingFace | CSR-7B-iter2 | 🤗 HuggingFace | CSR-13B-iter2 | 🤗 HuggingFace |
| CSR-iter2 | 🤗 HuggingFace | CSR-7B-iter3 | 🤗 HuggingFace | CSR-13B-iter3 | 🤗 HuggingFace |

The prompt dataset and mapping files between llava and hf-llava are available in './CSR/inference_csr/data'.

Evaluation

Here are three convenient ways to perform evaluations:

  1. Use the eval scripts provided in LLaVA.

  2. Use lmms-eval, a general evaluation platform.

  3. Use the CHAIR metrics as implemented in LURE.
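For intuition about what the CHAIR metrics measure, here is a simplified sketch, not LURE's implementation: CHAIR_i is the fraction of mentioned objects that are not in the image, and CHAIR_s is the fraction of captions that mention at least one such object. It assumes object mentions have already been extracted from each caption and mapped to the same category vocabulary as the ground-truth annotations (the full metric also handles synonyms and counts per-mention, which this set-based version skips). The function name `chair` is hypothetical.

```python
def chair(mentioned: list[set[str]], ground_truth: list[set[str]]) -> tuple[float, float]:
    """Compute (CHAIR_i, CHAIR_s) over a batch of captions.

    mentioned[k]: object categories extracted from caption k (assumed pre-extracted)
    ground_truth[k]: object categories annotated for image k
    CHAIR_i = hallucinated objects / all mentioned objects
    CHAIR_s = captions with at least one hallucinated object / all captions
    """
    if len(mentioned) != len(ground_truth) or not mentioned:
        raise ValueError("need matching, non-empty lists")
    total_objects = hallucinated_objects = hallucinated_captions = 0
    for objs, gt in zip(mentioned, ground_truth):
        hallucinated = objs - gt  # mentioned but not present in the image
        total_objects += len(objs)
        hallucinated_objects += len(hallucinated)
        hallucinated_captions += bool(hallucinated)
    chair_i = hallucinated_objects / total_objects if total_objects else 0.0
    chair_s = hallucinated_captions / len(mentioned)
    return chair_i, chair_s
```

Lower is better for both: a caption of "dog, frisbee" for an image containing a dog, a frisbee, and a person contributes no hallucinations, while "cat, car" for an image containing only a cat contributes one hallucinated object and one hallucinated caption.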

Acknowledgement

  • This repository is built upon LLaVA!
  • We thank the Center for AI Safety for supporting our computing needs. This research was supported by a Cisco Faculty Research Award.

About

[NeurIPS 2024] Calibrated Self-Rewarding Vision Language Models
