Python code for the paper "Let’s Fuse Step by Step: A Generative Fusion Decoding Algorithm with LLMs for Multi-modal Text Recognition" by Chan-Jan Hsu*, Yi-Chang Chen*, Feng-Ting Liao, Pei-Chen Ho, Yu-Hsiang Wang, Po-Chun Hsu, Da-shan Shiu
*Equal contribution
We introduce "Generative Fusion Decoding" (GFD), a novel shallow fusion framework for integrating Large Language Models (LLMs) into multi-modal text recognition systems such as automatic speech recognition (ASR) and optical character recognition (OCR). We derive the formulas necessary for GFD to operate across the mismatched token spaces of different models by mapping the text token space to a byte token space, enabling seamless fusion during decoding. The framework is plug-and-play, compatible with various auto-regressive models, and does not require re-training for feature alignment, thus overcoming limitations of previous fusion techniques. We highlight three main advantages of GFD. First, by simplifying the alignment of different models' sample spaces, GFD allows LLMs to correct errors in tandem with the recognition model, reducing computation latency. Second, GFD fully capitalizes on the in-context learning ability of LLMs, increasing robustness in long-form and instruction-aware speech recognition. Third, GFD enables fusing recognition models that are deficient in Chinese text recognition with LLMs extensively trained on Chinese. Our evaluation demonstrates that GFD significantly improves performance on ASR and OCR tasks, with ASR reaching state-of-the-art on the NTUML2021 benchmark. GFD is a significant step forward in model integration, offering a unified solution that could be widely applicable to leveraging existing pre-trained models through step-by-step fusion.
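To illustrate why moving to byte space sidesteps mismatched vocabularies, here is a toy sketch. The two vocabularies are invented for this example (they are not Whisper's or Breeze's), and the check below only tests byte-prefix consistency between partial hypotheses; it is not the GFD scoring procedure from the paper.

```python
# Hypothetical toy vocabularies mapping token IDs to raw bytes, as a
# byte-level tokenizer would. ASR token 0 holds only part of a UTF-8
# character, so it cannot be decoded as text on its own.
KI = "機".encode("utf-8")  # 3 bytes in UTF-8

ASR_VOCAB = {
    0: KI[:2],                         # first 2 bytes of "機"
    1: KI[2:] + "器".encode("utf-8"),  # last byte of "機" followed by "器"
    2: "學習".encode("utf-8"),
    3: " is fun".encode("utf-8"),
}
LLM_VOCAB = {
    0: KI,
    1: "器學".encode("utf-8"),
    2: "習 is".encode("utf-8"),
    3: " fun".encode("utf-8"),
    4: "雞".encode("utf-8"),  # a different character, used to show a conflict
}


def to_bytes(token_ids, vocab):
    """Concatenate the byte strings of a token sequence."""
    return b"".join(vocab[i] for i in token_ids)


def byte_consistent(asr_ids, llm_ids):
    """True if one partial hypothesis is a byte prefix of the other,
    i.e. the two models do not yet disagree on any byte."""
    a = to_bytes(asr_ids, ASR_VOCAB)
    b = to_bytes(llm_ids, LLM_VOCAB)
    return a.startswith(b) or b.startswith(a)


print(byte_consistent([0], [0]))     # partial "機" vs full "機"
print(byte_consistent([0, 1], [4]))  # "機器" vs "雞": conflict
```

The two token sequences segment the same sentence differently, yet their byte strings are identical, so each model's hypothesis can be compared with the other's at byte granularity even when a token boundary falls inside a character.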
- Clone the repository:
git clone https://github.com/mtkresearch/generative-fusion-decoding.git
cd generative-fusion-decoding
- Create a python virtual environment:
python -m venv venv
source venv/bin/activate  # On Windows use `venv\Scripts\activate`
- Install the required packages:
pip install -r requirements.txt
- Run the setup script:
python setup.py install
We run GFD on a machine with a single A6000 GPU.
Memory breakdown: ASR (Whisper Large) ~3GB; LLM (Breeze/Mistral) ~14GB.
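The memory figures above are in line with a simple weights-only estimate. The sketch below assumes 16-bit weights (2 bytes per parameter) and approximate parameter counts (about 1.55B for Whisper large and about 7.24B for Mistral 7B; check the model cards for exact values). It ignores activations and the KV cache, which add to the total, especially for long-form audio.

```python
def weight_memory_gb(num_params: float, bytes_per_param: int = 2) -> float:
    """Approximate memory for model weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


# Approximate parameter counts: assumptions, not values from this repository.
APPROX_PARAMS = {
    "Whisper large": 1.55e9,
    "Mistral 7B": 7.24e9,
}

for name, n in APPROX_PARAMS.items():
    print(f"{name}: ~{weight_memory_gb(n):.1f} GB of fp16/bf16 weights")
```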
🤗You can try it out hassle-free with Kaggle T4 GPUs here!
To run GFD on a single audio file with `benchmarks/run_single_file.py`, the following four arguments are required:
- `--model_name`: Specifies which type of model to use. The available option is:
  - `gfd`: The generative fusion decoding method.
- `--setting`: Specifies the configuration setting for the model. The available settings depend on `model_name`:
  - `asr-zhtw`: The complete version of our method's configuration, for testing on a Traditional Chinese sample.
  - `asr-zhtw-lmoff`: Uses our custom beam search method on the ASR model while neglecting the output from the LLM (`fusing_r = 0`), for a Traditional Chinese sample.
  - `asr-en`: The complete version of our method's configuration, for testing on an English sample.
  - `asr-en-lmoff`: Uses our custom beam search method on the ASR model while neglecting the output from the LLM (`fusing_r = 0`), for an English sample.
- `--audio_file_path`: The path to the audio file you want to process.
- `--result_output_path`: The path where the output result will be saved.
Example Usage
python benchmarks/run_single_file.py --model_name gfd --setting asr-zhtw --audio_file_path demo_examples/zh_news.wav --result_output_path output.txt
To run the benchmark dataset, the following four arguments are required:
- `--dataset_name`: Each dataset we tested has a short name for easy reference. When you run `benchmarks/run_benchmark.py`, the script automatically downloads the specified dataset from Hugging Face. The short names of the datasets used are:
  - `ml-lecture-2021-long`: Long-form audio recordings from the NTU 2021 machine learning lectures.
  - `formosa-long`: Long-form audio recordings in Traditional Chinese.
  - `fleurs-hk`: The Google FLEURS dataset, `yue_hant_hk` split.
  - `noisy-librispeech-10`: Librispeech with noise added to the audio (S/R = 10).
  - `noisy-librispeech-5`: Librispeech with noise added to the audio (S/R = 5).
  - `atco2`: Air Traffic Control Voice Communication dataset.
- `--model_name`: Specifies which type of model to use. There are two options:
  - `gfd`: The generative fusion decoding method.
  - `whisper`: The Hugging Face Whisper generation method.
- `--setting`: Specifies the configuration setting for the model. The available settings depend on `model_name`:
  - For `gfd`:
    - `asr-zhtw`: The complete version of our method's configuration, for testing on the Traditional Chinese dataset.
    - `asr-zhtw-lmoff`: Uses our custom beam search method on the ASR model while neglecting the output from the LLM (`fusing_r = 0`), for the Traditional Chinese dataset.
    - `asr-en`: The complete version of our method's configuration, for testing on the English dataset.
    - `asr-en-lmoff`: Uses our custom beam search method on the ASR model while neglecting the output from the LLM (`fusing_r = 0`), for the English dataset.
  - For `whisper`:
    - `whisper-zhtw`: The configuration for the Traditional Chinese dataset.
    - `whisper-en`: The configuration for the English dataset.
- `--output_dir`: The path to the directory where the model output will be stored. The outputs are stored in two subfolders:
  - `temp_results`: Stores the result of each sample in a JSON file.
  - `ds_result`: Stores the whole dataset along with the model predictions.
Example Usage
Here are some example commands for different configurations:
- Using the `gfd` model with the `asr-zhtw` setting on the `ml-lecture-2021-long` dataset:
python benchmarks/run_benchmark.py --dataset_name ml-lecture-2021-long --model_name gfd --setting asr-zhtw --output_dir result/
- Using the `whisper` model with the `whisper-zhtw` setting on the `ml-lecture-2021-long` dataset:
python benchmarks/run_benchmark.py --dataset_name ml-lecture-2021-long --model_name whisper --setting whisper-zhtw --output_dir result/
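To compare several configurations on one dataset, the commands above can be generated programmatically. The helper names (`build_benchmark_cmd`, `run_sweep`) and the per-run output directory layout (`result/<model>_<setting>`) are hypothetical choices for this sketch; only the script path and flags come from this README. With `dry_run=True` it just prints the commands.

```python
import os
import shlex
import subprocess
from typing import Iterable, List, Tuple

# Valid --setting values for each --model_name, as documented above.
MODEL_SETTINGS = {
    "gfd": {"asr-zhtw", "asr-zhtw-lmoff", "asr-en", "asr-en-lmoff"},
    "whisper": {"whisper-zhtw", "whisper-en"},
}


def build_benchmark_cmd(dataset_name: str, model_name: str,
                        setting: str, output_dir: str) -> List[str]:
    """Build one run_benchmark.py command, rejecting model/setting mismatches."""
    if setting not in MODEL_SETTINGS.get(model_name, set()):
        raise ValueError(f"setting {setting!r} is not valid for model {model_name!r}")
    return ["python", "benchmarks/run_benchmark.py",
            "--dataset_name", dataset_name,
            "--model_name", model_name,
            "--setting", setting,
            "--output_dir", output_dir]


def run_sweep(dataset_name: str, runs: Iterable[Tuple[str, str]],
              output_root: str = "result", dry_run: bool = True) -> List[List[str]]:
    """Run (or only print, if dry_run) one benchmark per (model_name, setting) pair.

    Each run gets its own output directory so results stay separate.
    """
    cmds = []
    for model_name, setting in runs:
        out_dir = os.path.join(output_root, f"{model_name}_{setting}")
        cmd = build_benchmark_cmd(dataset_name, model_name, setting, out_dir)
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)
        cmds.append(cmd)
    return cmds


run_sweep("ml-lecture-2021-long",
          [("gfd", "asr-zhtw"), ("gfd", "asr-zhtw-lmoff"), ("whisper", "whisper-zhtw")])
```

Validating the setting against the model name up front catches a mismatched pair (for example `whisper` with `asr-zhtw`) before any model is loaded.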
Using Multiple GPUs
If you have multiple GPUs, you can change the device settings in the model config files.
The model configurations for GFD and Whisper are under `config_files/model`, with Traditional Chinese and English versions for both models:
- GFD:
  - Traditional Chinese: `gfd-asr-zhtw.yaml`
  - English: `gfd-asr-en.yaml`
- Whisper:
  - Traditional Chinese: `whisper-zhtw.yaml`
  - English: `whisper-en.yaml`
`config_files/prompt` contains task-specific ASR and LLM prompt configurations for `gfd`. Prompt configuration files are named `{short dataset name}_prompt.yaml`.
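The file layout and naming rule above can be captured in a small lookup. This is a hypothetical helper, not part of the repository; the `(model_name, language)` keys with `"zhtw"`/`"en"` labels are an assumption made for the sketch, while the directory and file names come from the lists above.

```python
from pathlib import Path

CONFIG_ROOT = Path("config_files")

# Model config file names listed above, keyed by (model_name, language).
MODEL_CONFIGS = {
    ("gfd", "zhtw"): "gfd-asr-zhtw.yaml",
    ("gfd", "en"): "gfd-asr-en.yaml",
    ("whisper", "zhtw"): "whisper-zhtw.yaml",
    ("whisper", "en"): "whisper-en.yaml",
}


def model_config_path(model_name: str, language: str) -> Path:
    """Path of the model config for a model/language pair."""
    return CONFIG_ROOT / "model" / MODEL_CONFIGS[(model_name, language)]


def prompt_config_path(dataset_name: str) -> Path:
    """Prompt config path following {short dataset name}_prompt.yaml."""
    return CONFIG_ROOT / "prompt" / f"{dataset_name}_prompt.yaml"


for p in (model_config_path("gfd", "zhtw"), prompt_config_path("ml-lecture-2021-long")):
    print(p, "exists" if p.exists() else "missing")
```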
The general configuration files `gfd-asr-zhtw.yaml` and `gfd-asr-en.yaml` contain the options below. The meaning and choices for each option are listed in three parts, grouped by how likely you are to need to change them.
- `asr_model_path`: Path to the automatic speech recognition (ASR) model.
- `llm_model_path`: Path to the large language model (LLM).
- `lang`: Language code for the ASR model: `'en'` for English and `'zh'` for Chinese.
- `asr_device`: Device to run the ASR model on.
- `llm_device`: Device to run the LLM on.
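With more than one GPU, `asr_device` and `llm_device` are the settings to split across cards. The sketch below is a hypothetical helper operating on a config already loaded into a dict; it assumes PyTorch-style device strings such as `"cuda:0"`, which is an assumption about the config format. Given the memory breakdown above (~3GB for ASR, ~14GB for the LLM), giving the LLM its own GPU is the natural split.

```python
def assign_devices(config: dict, num_gpus: int) -> dict:
    """Return a copy of a config dict with asr_device / llm_device filled in.

    Hypothetical policy: with 2+ GPUs, ASR on cuda:0 and LLM on cuda:1;
    with 1 GPU, both share cuda:0; with none, fall back to CPU.
    """
    cfg = dict(config)  # do not mutate the caller's config
    if num_gpus >= 2:
        cfg["asr_device"], cfg["llm_device"] = "cuda:0", "cuda:1"
    elif num_gpus == 1:
        cfg["asr_device"] = cfg["llm_device"] = "cuda:0"
    else:
        cfg["asr_device"] = cfg["llm_device"] = "cpu"
    return cfg


print(assign_devices({"lang": "zh"}, num_gpus=2))
```

In practice `num_gpus` could come from `torch.cuda.device_count()`, and the updated dict would be written back to the YAML file with a YAML library.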
- `force_character_mode`: Output character mode when `lang == 'zh'`. Options are `'tc'` for Traditional Chinese characters, `'sc'` for Simplified Chinese characters, and `None` for no specific mode.
- `seg_with_overlap`: Default is `False`. When set to `True`, the audio is segmented with a short overlap between segments; when set to `False`, the audio is segmented without any overlap.
- `fusing_strategy`: Default is `simple`, in which the fusing score is the weighted sum of the ASR score and the LLM score: `score = fusing_r * llm_score + (1 - fusing_r) * asr_score`.
- `use_cache`: Default is `dynamic`. When set to `dynamic`, the model runs with the key-value (KV) cache enabled, which speeds up processing, especially for long-form audio. If set to `None`, the KV cache is disabled. If you are facing memory issues, consider setting it to `None` to release memory.
- `fusing_r`: Fusing ratio used by the fusing strategy to combine the ASR and LLM outputs. Setting it to 0 ignores the LLM, as in the `-lmoff` settings.
- `asr_attn_implementation`: ASR attention implementation. Options are `"eager"` (manual implementation of attention), `"sdpa"` (attention using `torch.nn.functional.scaled_dot_product_attention`), or `"flash_attention_2"` (attention using Dao-AILab/flash-attention). By default, SDPA is used if available (torch>=2.1.1); otherwise the default is the manual `"eager"` implementation.
- `llm_attn_implementation`: LLM attention implementation, with the same options and defaults as `asr_attn_implementation`.
- `llm_temp`: LLM temperature used to modulate the next-token probabilities.
- `transcription_cutoff`: Maximum number of tokens to retain from the previous transcription. If the previous transcription exceeds this limit, it is truncated to this length.
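The `simple` fusing strategy and `llm_temp` can be sketched on a toy candidate set. This assumes both scores are log-probabilities and that `llm_temp` divides the LLM logits before the softmax (the usual meaning of a temperature); it also assumes the ASR and LLM candidates are already aligned, whereas GFD performs that alignment in byte space, a step omitted here. The function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def log_softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled log-probabilities, computed stably."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - log_z for x in scaled]


def simple_fuse(asr_score: float, llm_score: float, fusing_r: float) -> float:
    """score = fusing_r * llm_score + (1 - fusing_r) * asr_score"""
    return fusing_r * llm_score + (1.0 - fusing_r) * asr_score


def rank_candidates(asr_logprobs: Sequence[float], llm_logits: Sequence[float],
                    fusing_r: float, llm_temp: float = 1.0) -> Tuple[List[int], List[float]]:
    """Rank aligned candidates by fused score, best first."""
    llm_logprobs = log_softmax(llm_logits, llm_temp)
    fused = [simple_fuse(a, l, fusing_r) for a, l in zip(asr_logprobs, llm_logprobs)]
    order = sorted(range(len(fused)), key=lambda i: fused[i], reverse=True)
    return order, fused


asr = [math.log(0.6), math.log(0.4)]  # ASR slightly prefers candidate 0
llm_logits = [0.0, 3.0]               # LLM strongly prefers candidate 1
print(rank_candidates(asr, llm_logits, fusing_r=0.0)[0])  # ASR only
print(rank_candidates(asr, llm_logits, fusing_r=0.5)[0])  # LLM can overrule
```

With `fusing_r = 0` the ranking follows the ASR model alone, which is what the `-lmoff` settings do; raising `fusing_r` lets a confident LLM overrule the ASR model.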
- `repetition_penalty`: The penalty applied to repeated tokens during generation. A higher value increases the penalty, making the model less likely to repeat the same tokens. The penalty is applied only if `repetition_penalty` is greater than `repetition_penalty_threshold`.
- `repetition_penalty_last`: Repetition penalty for the last tokens; specifies the number of last tokens to apply the repetition penalty to.
- `repetition_penalty_window`: The window size for applying the repetition penalty. The penalty is applied to tokens within this window of the current token being processed. For example, if `repetition_penalty_window` is set to `50`, the penalty is applied to the last 50 tokens before the current token.
- `repetition_penalty_threshold`: The threshold for applying the repetition penalty. If `repetition_penalty` is greater than this threshold, the penalty mechanism is activated.
- `beam_terminated_strategy`: Beam search termination strategy. The default is `when_all_ended`, which terminates beam search when all beams reach the end.
- `beam_select_strategy`: Beam selection strategy. Options are `'best'`, which selects the beam with the highest score, and `'longest'`, which selects the beam with the longest transcription.
- `beam_max_decode_len`: Maximum decode length for beam search, i.e. the maximum length of the decoded sequence during beam search.
- `beam_max_len_diff`: Maximum length difference for beam search, i.e. the maximum difference in length between beams during beam search.
- `beam_max_len`: Maximum length for beam search. A default value of `-1` means no limit.
- `beam_min_len`: Minimum length for beam search.
- `logprob_min`: Minimum log probability for the LLM output.
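Two of the options above, sketched under stated assumptions. The repetition penalty below uses the common CTRL-style rule (the one Hugging Face's `RepetitionPenaltyLogitsProcessor` implements): positive logits of recently seen tokens are divided by the penalty and negative ones multiplied, gated by the threshold and limited to a window. How this repository combines `repetition_penalty_last` with `repetition_penalty_window` is not described here, so the sketch uses a single window. The beam selection function mirrors the `'best'` and `'longest'` strategies. Both function names are hypothetical.

```python
from typing import List, Sequence, Tuple


def apply_repetition_penalty(logits: Sequence[float], prev_tokens: Sequence[int],
                             penalty: float, window: int,
                             threshold: float = 1.0) -> List[float]:
    """CTRL-style penalty on tokens seen in the last `window` tokens.

    Dividing positive logits and multiplying negative ones both make a repeated
    token less likely. Nothing changes unless penalty > threshold.
    """
    out = list(logits)
    if penalty <= threshold:
        return out
    recent = prev_tokens[-window:] if window > 0 else prev_tokens
    for tok in set(recent):
        out[tok] = out[tok] / penalty if out[tok] > 0 else out[tok] * penalty
    return out


def select_beam(beams: Sequence[Tuple[float, str]], strategy: str = "best") -> Tuple[float, str]:
    """Pick a finished (score, transcription) beam.

    'best' keeps the highest score; 'longest' keeps the longest transcription.
    """
    if strategy == "best":
        return max(beams, key=lambda b: b[0])
    if strategy == "longest":
        return max(beams, key=lambda b: len(b[1]))
    raise ValueError(f"unknown beam_select_strategy: {strategy!r}")


print(apply_repetition_penalty([2.0, -1.0, 0.5], prev_tokens=[0, 1], penalty=2.0, window=50))
beams = [(-1.2, "hello"), (-0.8, "hi"), (-2.0, "hello there")]
print(select_beam(beams, "best"), select_beam(beams, "longest"))
```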
After running the model on a benchmark dataset, you can evaluate the result by calculating the Mixed Error Rate (MER) with the provided `benchmarks/calculate_mer.py` script. The script requires the following arguments:
- `--dataset_name`: The short name of the benchmark dataset you want to evaluate.
- `--output_dir`: The output directory that stores the output from the model.
Example Usage
python benchmarks/calculate_mer.py --dataset_name ml-lecture-2021-long --output_dir result/
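For intuition about the metric, here is a sketch of a common MER definition for mixed Chinese/English text: CJK ideographs are scored per character, other text per whitespace-separated word, and MER is the total edit distance divided by the number of reference tokens. This is an assumption about the definition, not a copy of `calculate_mer.py`, which may tokenize and normalize differently (for example with Whisper's `EnglishTextNormalizer`). This sketch only lowercases and does not strip punctuation.

```python
import re
from typing import List, Sequence

# One token per CJK Unified Ideograph; otherwise one token per run of
# non-space, non-CJK characters (roughly an English word).
_TOKEN_RE = re.compile(r"[\u4e00-\u9fff]|[^\s\u4e00-\u9fff]+")


def mixed_tokens(text: str) -> List[str]:
    return _TOKEN_RE.findall(text.lower())


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance over tokens, keeping one DP row at a time."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur = [i]
        for j, h in enumerate(hyp, 1):
            cur.append(min(prev[j] + 1,             # deletion
                           cur[j - 1] + 1,          # insertion
                           prev[j - 1] + (r != h))) # substitution or match
        prev = cur
    return prev[-1]


def mer(refs: Sequence[str], hyps: Sequence[str]) -> float:
    """Corpus-level MER: total token errors / total reference tokens."""
    errors = total = 0
    for ref, hyp in zip(refs, hyps):
        r, h = mixed_tokens(ref), mixed_tokens(hyp)
        errors += edit_distance(r, h)
        total += len(r)
    return errors / total if total else 0.0


print(mixed_tokens("我愛 Machine Learning"))
print(mer(["我愛機器學習"], ["我愛機械學習"]))
```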
The table below compares the error rates (lower is better) of each method on multiple datasets:
Dataset | GFD | GFD Ablation* | Whisper (5 beams) |
---|---|---|---|
NTUML2021-long | 6.05 | 6.09 | 9.56 |
FormosaSpeech-long | 20.37 | 22.35 | 23.78 |
Fleurs-HK | 5.91 | 7.06 | 6.87 |
Librispeech-Noise (S/R = 10) | 5.07 | 5.33 | 5.16 |
Librispeech-Noise (S/R = 5) | 7.09 | 7.37 | 7.28 |
*GFD Ablation sets `fusing_r = 0`, which corresponds to running Whisper with our custom beam search algorithm. Both GFD Ablation and Whisper serve as baselines for GFD.
Dataset | GFD | GFD | GFD Ablation | GFD Ablation |
---|---|---|---|---|
ASR prompting | yes | no | yes | no |
LLM prompting | yes | yes | NA | NA |
ATCO-2 | - | - | 31.48 / 42.68** | - |
** The former score is computed from results processed with the Whisper `EnglishTextNormalizer`. The latter is computed from transcriptions that are only converted to lowercase, without further normalization. They correspond to the Norm and Raw columns in the paper, respectively.
Warning: This project uses tokenizers with custom tokenizer functions mostly to deal with byte string tokenizations, and has only been tested with the Mistral and Breeze models. Using other models may result in errors or unexpected behavior. Please ensure compatibility before using it with other models.
If you like our work, please cite:
@article{hsu2024let,
title={Let's Fuse Step by Step: A Generative Fusion Decoding Algorithm with LLMs for Multi-modal Text Recognition},
author={Hsu, Chan-Jan and Chen, Yi-Chang and Liao, Feng-Ting and Ho, Pei-Chen and Wang, Yu-Hsiang and Hsu, Po-Chun and Shiu, Da-Shan},
journal={arXiv preprint arXiv:2405.14259},
year={2024}
}