
An open-source implementation for fine-tuning Pixtral by MistralAI.


Fine-tuning Pixtral

This repository contains a script for training a Transformers-compatible Pixtral-12B.

However, the model currently supports only a batch size of 1, so fine-tuning can take a long time.

Other projects

[Phi3-Vision Finetuning]
[Llama3.2-Vision Finetuning]
[Qwen2-VL Finetuning]
[Molmo Finetune]


Supported Features

  • Deepspeed
  • LoRA/QLoRA
  • Full-finetuning
  • Enable finetuning vision_model while using LoRA.
  • Disable/enable Flash Attention 2
  • Multi-image and video training
  • Training optimized with Liger Kernel

Installation

Install the required packages using environment.yaml.

Using environment.yaml

conda env create -f environment.yaml
conda activate pixtral
pip install flash-attn==2.5.8 --no-build-isolation

Note: You should install flash-attn after installing the other packages.
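As an optional sanity check before training, the short Python sketch below reports whether the flash_attn package is importable, which helps decide whether to keep Flash Attention 2 enabled or pass --disable_flash_attn2. The helper name flash_attn_available is hypothetical and not part of this repository; it only checks that the module can be found, not that its CUDA kernels actually work on your GPU.

```python
import importlib.util


def flash_attn_available() -> bool:
    """Return True if the flash_attn package can be located in this environment."""
    # find_spec only locates the module; it does not import it or load CUDA kernels.
    return importlib.util.find_spec("flash_attn") is not None


if __name__ == "__main__":
    if flash_attn_available():
        print("flash-attn found; Flash Attention 2 can stay enabled.")
    else:
        print("flash-attn not found; install it or pass --disable_flash_attn2.")
```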

Dataset Preparation

The script requires a dataset formatted according to the LLaVA specification. The dataset should be a JSON file where each entry contains information about conversations and images. Ensure that the image paths in the dataset match the provided --image_folder.

When using a multi-image dataset, every image token should be <image>, and the image file names should be given as a list. See the example below and format your data accordingly.
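Before training, you may want to check that each entry follows these rules. The sketch below uses hypothetical helpers (check_entry and check_dataset are not part of this repository) and assumes each <image> token corresponds to exactly one file in "image", which may be a single string or a list.

```python
import json
from typing import Any


def count_images(entry: dict[str, Any]) -> int:
    """Number of image files referenced by an entry ("image" may be a string or a list)."""
    images = entry.get("image")
    if images is None:
        return 0
    if isinstance(images, list):
        return len(images)
    return 1  # single-image entries store one file name as a plain string


def check_entry(entry: dict[str, Any]) -> list[str]:
    """Return human-readable problems for one LLaVA-style entry (empty list if it looks fine)."""
    problems = []
    # Count <image> placeholders across every turn of the conversation.
    n_tokens = sum(
        turn.get("value", "").count("<image>") for turn in entry.get("conversations", [])
    )
    n_images = count_images(entry)
    if n_tokens != n_images:
        problems.append(
            f"{entry.get('id', '?')}: {n_tokens} <image> token(s) but {n_images} image file(s)"
        )
    return problems


def check_dataset(path: str) -> list[str]:
    """Load a LLaVA-format JSON file and collect problems from every entry."""
    with open(path, encoding="utf-8") as f:
        entries = json.load(f)
    problems: list[str] = []
    for entry in entries:
        problems.extend(check_entry(entry))
    return problems
```

For example, check_dataset("path/to/train.json") (a placeholder path) returns an empty list when every entry's token count matches its image count.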

Example for single-image dataset
[
  {
    "id": "000000033471",
    "image": "000000033471.jpg",
    "conversations": [
      {
        "from": "human",
        "value": "<image>\nWhat are the colors of the bus in the image?"
      },
      {
        "from": "gpt",
        "value": "The bus in the image is white and red."
      },
      {
        "from": "human",
        "value": "What feature can be seen on the back of the bus?"
      },
      {
        "from": "gpt",
        "value": "The back of the bus features an advertisement."
      },
      {
        "from": "human",
        "value": "Is the bus driving down the street or pulled off to the side?"
      },
      {
        "from": "gpt",
        "value": "The bus is driving down the street, which is crowded with people and other vehicles."
      }
    ]
  }
  ...
]
Example for multi-image dataset
[
  {
    "id": "000000033471",
    "image": ["000000033471.jpg", "000000033472.jpg"],
    "conversations": [
      {
        "from": "human",
        "value": "<image>\n<image>\nIs the perspective of the camera different?"
      },
      {
        "from": "gpt",
        "value": "Yes, the perspective of the camera is different."
      }
    ]
  }
  ...
]
Example for video dataset
[
  {
    "id": "sample1",
    "video": "sample1.mp4",
    "conversations": [
      {
        "from": "human",
        "value": "<video>\nWhat is going on in this video?"
      },
      {
        "from": "gpt",
        "value": "A man is walking down the road."
      }
    ]
  }
  ...
]

Note: Pixtral does not officially support video, but it does support multiple images, so you can feed a video as a sequence of frames.
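One common way to turn a video into a sequence of frames is to sample frame indices evenly across the clip. The sketch below assumes uniform sampling capped at --max_num_frames (default 10); the function name is hypothetical, the repository's own frame selection may differ, and decoding the frames themselves (e.g. with a video library) is left out.

```python
def sample_frame_indices(total_frames: int, max_num_frames: int = 10) -> list[int]:
    """Pick up to max_num_frames evenly spaced frame indices from range(total_frames)."""
    if total_frames <= 0 or max_num_frames <= 0:
        return []
    if total_frames <= max_num_frames:
        # Short clip: keep every frame.
        return list(range(total_frames))
    step = total_frames / max_num_frames
    # Split the clip into max_num_frames equal segments and take the middle of each.
    return [int(step * i + step / 2) for i in range(max_num_frames)]
```

The selected frames are then passed to the model as multiple images, in order.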

Training

Note: Mixed datasets (e.g., where some samples in a batch have images and others don't) are only supported with zero2.
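To find out whether your dataset is mixed in this sense, a quick check like the following can help. It is a hypothetical helper, not part of this repository, and assumes text-only entries simply omit the "image" and "video" keys.

```python
from typing import Any, Iterable


def has_visual_input(entry: dict[str, Any]) -> bool:
    """True if the entry references an image or a video."""
    return "image" in entry or "video" in entry


def is_mixed_dataset(entries: Iterable[dict[str, Any]]) -> bool:
    """True if some entries have images/videos while others are text-only."""
    kinds = {has_visual_input(entry) for entry in entries}
    return kinds == {True, False}
```

If is_mixed_dataset returns True, keep --deepspeed pointed at scripts/zero2.json.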

To run the training script, use the following command:

Full Finetuning

bash scripts/finetune.sh

Full Finetuning with 8-bit

bash scripts/finetune_8bit.sh

This script fine-tunes the model with 8-bit AdamW and an fp8 model dtype. Use it if you run out of VRAM.

Finetune with LoRA

If you want to train only the language model with LoRA and perform full training for the vision model:

bash scripts/finetune_lora.sh

If you want to train both the language model and the vision model with LoRA:

bash scripts/finetune_lora_vision.sh

IMPORTANT: If you want to tune embed_token with LoRA, you need to tune lm_head as well. Note: Freezing the LLM only works without LoRA (including vision_model LoRA).
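The two rules above can be checked before launching a run. The sketch below is hypothetical: it assumes the Hugging Face module names embed_tokens and lm_head and a single use_lora flag covering both language-model and vision_model LoRA; neither the function nor its arguments are flags of this repository's scripts.

```python
def check_lora_settings(lora_target_modules: list[str], use_lora: bool, freeze_llm: bool) -> list[str]:
    """Return a list of rule violations (empty if the combination is allowed)."""
    errors = []
    targets = set(lora_target_modules)
    # Rule 1: tuning the token embeddings with LoRA requires tuning lm_head too.
    if use_lora and "embed_tokens" in targets and "lm_head" not in targets:
        errors.append("embed_tokens is tuned with LoRA, so lm_head must be tuned too")
    # Rule 2: freezing the LLM is only supported when no LoRA is used at all.
    if freeze_llm and use_lora:
        errors.append("freezing the LLM only works without LoRA (including vision_model LoRA)")
    return errors
```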

Training arguments
  • --deepspeed (str): Path to DeepSpeed config file (default: "scripts/zero2.json").
  • --data_path (str): Path to the LLaVA formatted training data (a JSON file). (Required)
  • --image_folder (str): Path to the images folder as referenced in the LLaVA formatted training data. (Required)
  • --model_id (str): Path to the Pixtral model. (Required)
  • --output_dir (str): Output directory for model checkpoints
  • --num_train_epochs (int): Number of training epochs (default: 1).
  • --per_device_train_batch_size (int): Training batch size per GPU per forward step.
  • --gradient_accumulation_steps (int): Gradient accumulation steps (default: 4).
  • --freeze_vision_tower (bool): Option to freeze vision_model (default: False).
  • --freeze_llm (bool): Option to freeze LLM (default: False).
  • --tune_merger (bool): Option to tune projector (default: True).
  • --num_lora_modules (int): Number of target modules to which LoRA is added (-1 means all layers).
  • --vision_lr (float): Learning rate for vision_model.
  • --merger_lr (float): Learning rate for merger(projector).
  • --learning_rate (float): Learning rate for language module.
  • --max_num_frames (int): Maximum number of frames per video in a video dataset (default: 10).
  • --bf16 (bool): Option for using bfloat16.
  • --fp16 (bool): Option for using fp16.
  • --min_pixels (int): Option for minimum input tokens.
  • --max_pixels (int): Option for maximum input tokens.
  • --lora_namespan_exclude (str): Module name spans to exclude when adding LoRA.
  • --max_seq_length (int): Maximum sequence length (default: 32K).
  • --bits (int): Quantization bits (default: 16).
  • --disable_flash_attn2 (bool): Disable Flash Attention 2.
  • --report_to (str): Reporting tool (choices: 'tensorboard', 'wandb', 'none') (default: 'tensorboard').
  • --logging_dir (str): Logging directory (default: "./tf-logs").
  • --lora_rank (int): LoRA rank (default: 128).
  • --lora_alpha (int): LoRA alpha (default: 256).
  • --lora_dropout (float): LoRA dropout (default: 0.05).
  • --logging_steps (int): Logging steps (default: 1).
  • --dataloader_num_workers (int): Number of data loader workers (default: 4).
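Because the model currently supports only a batch size of 1, the effective batch size comes mostly from --gradient_accumulation_steps and the number of GPUs. The sketch below shows the usual arithmetic, assuming plain data parallelism across GPUs; the helper names are hypothetical and the step count is an approximation.

```python
import math


def effective_batch_size(per_device_train_batch_size: int, gradient_accumulation_steps: int, num_gpus: int) -> int:
    """Samples contributing to one optimizer step under plain data parallelism."""
    return per_device_train_batch_size * gradient_accumulation_steps * num_gpus


def optimizer_steps_per_epoch(num_samples: int, per_device_train_batch_size: int,
                              gradient_accumulation_steps: int, num_gpus: int) -> int:
    """Approximate optimizer steps per epoch (a final partial step counts as one)."""
    batch = effective_batch_size(per_device_train_batch_size, gradient_accumulation_steps, num_gpus)
    return math.ceil(num_samples / batch)
```

For example, with --per_device_train_batch_size 1, the default --gradient_accumulation_steps 4 and 8 GPUs, each optimizer step sees 32 samples, so 10,000 samples take about 313 optimizer steps per epoch.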

Note: The learning rate of vision_model should be 5x to 10x smaller than that of the language model.
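As a worked example of this rule, the hypothetical helper below (not part of the repository) turns --learning_rate into a suggested range for --vision_lr.

```python
def vision_lr_range(learning_rate: float) -> tuple[float, float]:
    """(lowest, highest) suggested --vision_lr: 10x to 5x smaller than --learning_rate."""
    return learning_rate / 10, learning_rate / 5
```

With --learning_rate 1e-5, this suggests a --vision_lr between roughly 1e-6 and 2e-6.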

Train with video dataset

You can train the model on a video dataset. However, Pixtral does not officially support video, so this code processes each video as a sequence of images: you select specific frames and treat them as multiple images for training. You can also set LoRA configs and train with LoRA.

bash scripts/finetune_video.sh

Note: You should adjust max_num_frames based on the available VRAM.

If you run out of VRAM, you can use zero3_offload instead of zero3; however, zero3 is preferred.

Merge LoRA Weights

bash scripts/merge_lora.sh

Note: Remember to replace the paths in finetune.sh or finetune_lora.sh with your own paths (and in merge_lora.sh when using LoRA).

Issue with libcudnn error

Could not load library libcudnn_cnn_train.so.8. Error: /usr/local/cuda-12.1/lib/libcudnn_cnn_train.so.8: undefined symbol: _ZN5cudnn3cnn34layerNormFwd_execute_internal_implERKNS_7backend11VariantPackEP11CUstream_stRNS0_18LayerNormFwdParamsERKNS1_20NormForwardOperationEmb, version libcudnn_cnn_infer.so.8

You can run unset LD_LIBRARY_PATH to work around this error. See this issue for details.
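If you launch training from Python instead of a shell, the same workaround can be applied by passing a copy of the environment without LD_LIBRARY_PATH to the child process. This is a minimal sketch; the helper name is hypothetical, and the commented launch line assumes you run it from the repository root.

```python
import os
from typing import Mapping, Optional


def env_without_ld_library_path(base: Optional[Mapping[str, str]] = None) -> dict:
    """Copy of the environment with LD_LIBRARY_PATH removed (like `unset LD_LIBRARY_PATH`)."""
    env = dict(os.environ if base is None else base)
    env.pop("LD_LIBRARY_PATH", None)
    return env


# Example launch (not executed here):
# import subprocess
# subprocess.run(["bash", "scripts/finetune.sh"], env=env_without_ld_library_path(), check=True)
```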

TODO

  • Support batch size > 1

Known Issues

License

This project is licensed under the Apache-2.0 License. See the LICENSE file for details.

Citation

If you find this repository useful in your project, please consider giving a ⭐ and citing:

@misc{Pixtral-Finetuning,
  author = {Yuwon Lee},
  title = {Pixtral-Finetune},
  year = {2024},
  publisher = {GitHub},
  url = {https://github.com/2U1/Pixtral-Finetune}
}

Acknowledgement

This project is based on

  • LLaVA-NeXT: An amazing open-source project of LMM.
  • Pixtral-12B: Transformers-compatible version of Pixtral-12B
