Medical-CXR-VQA

Medical-CXR-VQA is an LLM-constructed, large-scale chest X-ray dataset for the medical visual question answering task. This repository provides the code for generating the Medical-CXR-VQA dataset, as proposed in our paper, "Interpretable medical image Visual Question Answering via multi-modal relationship graph learning."

For more information about the dataset and the method, please refer to our paper.

For the code of our multi-modal relationship graph learning method, please refer to MMRGL.

The Medical-CXR-VQA dataset is now available on PhysioNet.

Data

After downloading the dataset from PhysioNet, put the files below into code/data (a minimal loading sketch follows the list).

  • medical_cxr_vqa_questions.csv: The generated answers and questions for the Medical-CXR-VQA dataset.

  • all_diseases.json: The key information extracted from all the reports by the fine-tuned LLaMA 2.

  • mimic_all.csv: All metadata related to the CXR studies.

  • all_diseases_gpt4_100.json: The 100 examples (key information) generated by GPT-4 and used for fine-tuning LLaMA 2.
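
For a quick check that the files are in place, they can be loaded as follows. This is a minimal sketch: it assumes the files sit under code/data as described and makes no assumptions about the CSV columns or JSON schema.

import json
import pandas as pd

# Question/answer pairs and study metadata (CSV files from PhysioNet).
questions = pd.read_csv("code/data/medical_cxr_vqa_questions.csv")
mimic_all = pd.read_csv("code/data/mimic_all.csv")

# Key information extracted by the fine-tuned LLaMA 2, plus the GPT-4 golden set.
with open("code/data/all_diseases.json") as f:
    all_diseases = json.load(f)
with open("code/data/all_diseases_gpt4_100.json") as f:
    gpt4_100 = json.load(f)

print(questions.shape, mimic_all.shape, len(all_diseases), len(gpt4_100))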

Libs

Please put the files below into code/libs (a loading sketch for the frequency dictionaries follows the list).

  • disease_lib_llm.csv: The initial disease name library used for generating questions and answers.

  • level_lib.csv: The initial level library used for generating questions and answers.

  • location_lib.csv: The initial location library used for generating questions and answers.

  • type_lib.csv: The initial type library used for generating questions and answers.

  • postlocation_lib.csv: The initial postlocation library used for generating questions and answers.

  • position_change.csv: The initial position change library used for generating questions and answers.

  • entity_dict.json: Disease names with appearance frequencies in the KeyInfo set.

  • type_dict.json: Disease types with appearance frequencies in the KeyInfo set.

  • level_dict.json: Disease levels with appearance frequencies in the KeyInfo set.

  • location_dict.json: Disease locations with appearance frequencies in the KeyInfo set.
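
The four *_dict.json files can be inspected with a few lines of Python. This is a sketch under the assumption that each file maps a phrase (disease name, type, level, or location) to its occurrence count in the KeyInfo set; verify the schema against your downloaded copies.

import json

for name in ["entity_dict", "type_dict", "level_dict", "location_dict"]:
    with open(f"code/libs/{name}.json") as f:
        freq = json.load(f)
    # Print the five most frequent entries, assuming a {phrase: count} mapping.
    top = sorted(freq.items(), key=lambda kv: kv[1], reverse=True)[:5]
    print(name, top)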

File explanation

Below are the files provided in code/data in this repository.

  1. "system_text.txt" is the system prompt for ChatGPT.
  2. "simple_system_text.txt" is the simplified system prompt used for fine-tuned LLama 2.
  3. "user_text.txt" is the user input for ChatGPT. In our case, this is the input report. This file is used for function_testing.py only.
  4. "id100.pkl": This file stores the IDs for 100 examples that have been extracted for doctor evaluation. It is used for comparison between different dataset construction methods on the same data.

Steps to generate dataset

Please be aware that the generated dataset may not exactly match the one we provide, due to randomness in the generation process. The code provided here is for reference purposes.

First, change into the code directory:

cd code

1. (Optional) Prepare the 100 training examples

This step can be skipped because we provide the annotated golden set for fine-tuning LLaMA 2. The code for this step is still included here for reference.

The generated all_diseases_gpt4_100.json is provided in the PhysioNet dataset.

1.1. Prerequisites

  1. Azure OpenAI Service access
    • According to the PhysioNet Credentialed Data Use Agreement, sharing access to the data with third parties is prohibited, including sending it through APIs provided by companies such as OpenAI or using it in online platforms like ChatGPT. The Azure OpenAI Service is therefore suggested.
  2. Download mimic-cxr-reports.zip from the MIMIC-CXR database.

1.2. Prepare the GPT-4-generated training data:

python main.py  --model gpt-4 --system_text system_text.txt --select_num 100 --output_name data/all_diseases_gpt4_100.json

Parameter explanations:

  • --model is the model name. The default value is gpt-4; you can also use gpt-35-turbo. Note that the value passed here must match your deployment name in the Azure portal, not the underlying model name. In our case, "gpt-4" and "gpt-35-turbo" are the deployment names of our two models (see the client sketch after this list).
  • --select_num defines the number of examples to extract.
  • --output_name defines the name and the path to the output file.
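
The sketch below illustrates the deployment-name point using the Azure OpenAI Python client (openai>=1.0). The endpoint, key, and API version are placeholders, and main.py may set up its client differently; this only shows that the value passed to --model is used as the Azure deployment name.

from openai import AzureOpenAI  # openai>=1.0

client = AzureOpenAI(
    azure_endpoint="https://<your-resource>.openai.azure.com/",  # placeholder
    api_key="<your-azure-openai-key>",                           # placeholder
    api_version="2024-02-01",                                    # placeholder
)
response = client.chat.completions.create(
    model="gpt-4",  # the Azure deployment name, matching --model
    messages=[
        {"role": "system", "content": open("data/system_text.txt").read()},  # path may differ
        {"role": "user", "content": "<one MIMIC-CXR report>"},
    ],
)
print(response.choices[0].message.content)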

After generating data/all_diseases_gpt4_100.json, run preprocess_data() in fine_tune_llama.py to generate the training data (gpt4_100_data_effective_finetune_llama2.json) for fine-tuning LLaMA 2.
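
Assuming preprocess_data() takes no arguments and reads/writes its default paths (check the actual signature in fine_tune_llama.py), it can be run directly:

# Run from the code/ directory.
from fine_tune_llama import preprocess_data

preprocess_data()  # expected to emit gpt4_100_data_effective_finetune_llama2.json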

2. Set up LLaMA-Factory

Please refer to LLaMA-Factory for installation instructions.

Next,

  1. Download the Llama 2 checkpoint (please refer to this link), then store it in the path LLaMA-Factory/model_checkpoints/llama-2-70b-chat.
  2. Move the provided ds_config.json to the LLaMA-Factory root directory.
  3. Modify dataset_info.json in LLaMA-Factory/data by adding a definition for the newly created fine-tuning dataset. The format is shown below; file_name must match the output file generated with GPT-4 in Step 1.
  "gpt4_100": {
    "file_name": "gpt4_100_data_effective_finetune_llama2.json",
    "columns": {
      "prompt": "query",
      "query": "",
      "response": "output",
      "history": ""
    }
  }
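
Given this column mapping, each record in the fine-tuning file should expose a "query" field (the prompt) and an "output" field (the response). A quick consistency check, assuming the file has been copied into LLaMA-Factory/data:

import json

with open("LLaMA-Factory/data/gpt4_100_data_effective_finetune_llama2.json") as f:
    records = json.load(f)
# Every record should provide the fields referenced by dataset_info.json.
missing = [i for i, r in enumerate(records) if "query" not in r or "output" not in r]
print(f"{len(records)} records, {len(missing)} missing 'query' or 'output'")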

3. Fine-tune the model using the following command:

deepspeed --num_gpus 6 --master_port=9901 src/train_bash.py \
    --deepspeed ds_config.json \
    --stage sft \
    --model_name_or_path ../model_checkpoints/llama-2-70b-chat \
    --do_train \
    --dataset gpt4_100 \
    --template llama2 \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir ../model_checkpoints/llama_finetune_gpt4_100 \
    --overwrite_cache \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 1e-3 \
    --num_train_epochs 3.0 \
    --plot_loss \
    --fp16 \
    --overwrite_output_dir
  • --dataset: the dataset name defined in dataset_info.json in Step 2.
  • --num_gpus: the number of GPUs used for fine-tuning.

Here, we provide our fine-tuned LLaMA 2 checkpoint at this link: llama_finetune_gpt4_100. Please unzip it before use.

4. Combine the fine-tuned LoRA weights with the original model using the following command (change the model paths as needed):

python src/export_model.py \
    --model_name_or_path ../model_checkpoints/llama-2-70b-chat \
    --template llama2 \
    --finetuning_type lora \
    --checkpoint_dir ../model_checkpoints/llama_finetune_gpt4_100 \
    --output_dir ../model_checkpoints/llama_finetune_gpt4_100_output
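
As a quick sanity check that the merged checkpoint loads as a standalone model, something like the following can be used (a sketch with Hugging Face transformers; loading a 70B model this way requires substantial GPU/CPU memory and the accelerate package for device_map="auto"):

from transformers import AutoModelForCausalLM, AutoTokenizer

path = "../model_checkpoints/llama_finetune_gpt4_100_output"
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", torch_dtype="auto")
print(model.config.model_type, f"{model.num_parameters() / 1e9:.1f}B parameters")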

5. Generate the entire dataset using the fine-tuned model

python main.py  --model llama_finetune_gpt4_100_output --output_name data/all_diseases_chatgptRaw.json

6. Post-processing and follow-up questions

Post-processing: llama2_postprocessing.py

python llama2_postprocessing.py --input_path data/all_diseases_chatgptRaw.json --output_path data/all_diseases_standardized.json

Follow-up question: follow_up_gen.py

python follow_up_gen.py --model_name llama_finetune_gpt4_100_output --raw_file data/all_diseases_standardized.json --followup_file data/all_diseases_standardized_fu.json
  • --raw_file: path to input file
  • --followup_file: path to output file

These two steps can be repeated in alternation.

7. Question generation

python question_gen.py --json_path <path_to_the_final_all_diseases_json>
