Fusion of source and target representations in LoRA adapters inserted into the query and value matrices of the attention modules. The representations are fused in the adapter bottlenecks.
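For illustration only, the minimal PyTorch sketch below shows one way such bottleneck fusion can be realized for the element-wise fusion functions; it is not the repository's implementation, all names are hypothetical, and the cross-attention variant is omitted:

```python
import torch
import torch.nn as nn

class FusedLoRA(nn.Module):
    """Illustrative LoRA adapter that fuses source and target representations in its bottleneck."""

    def __init__(self, hidden_dim: int, r: int = 8, alpha: int = 16, fusion_fn: str = "add"):
        super().__init__()
        self.down = nn.Linear(hidden_dim, r, bias=False)  # LoRA down-projection (A)
        self.up = nn.Linear(r, hidden_dim, bias=False)    # LoRA up-projection (B)
        nn.init.zeros_(self.up.weight)                    # standard LoRA init: adapter starts as a no-op
        self.scaling = alpha / r
        self.fusion_fn = fusion_fn

    def forward(self, h_src: torch.Tensor, h_tgt: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor:
        # Project source- and target-language hidden states into the low-rank bottleneck ...
        z_src, z_tgt = self.down(h_src), self.down(h_tgt)
        # ... fuse them there ...
        if self.fusion_fn == "add":
            z = z_src + z_tgt
        elif self.fusion_fn == "add_relu":
            z = torch.relu(z_src + z_tgt)
        elif self.fusion_fn == "mul":
            z = z_src * z_tgt
        else:
            raise ValueError(f"unsupported fusion_fn: {self.fusion_fn}")
        # ... and add the up-projected result to the frozen query/value projection output.
        return base_out + self.scaling * self.up(z)
```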
1. Clone the repository
```bash
git clone https://github.com/*/FLARE
```
Ensure you have the required dependencies installed. Check the requirements.txt
file for details.
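For example, the dependencies can typically be installed with pip (assuming a standard Python environment):

```bash
pip install -r requirements.txt
```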
2. Task Fine-tuning on English Data
Fine-tune a pretrained language model (e.g., Gemma 2) on English task data such as XNLI:
```bash
bash train_task_ft.sh
```
3. Machine Translation
Translate the XNLI training and evaluation data from English to the target language (e.g., Spanish), and translate the test data from Spanish to English.
```bash
bash run_translate.sh
```
4. Cross-lingual Transfer with FLARE
Adapt the fine-tuned Gemma 2 model from English XNLI to Spanish:
```bash
bash train_flare.sh
```
Supported Datasets
| Datasets | Links |
| --- | --- |
| XNLI | Paper, Data |
| NusaX | Paper, Data |
| TyDiQA | Paper, Data |
Supported language models
| Model | Size | Links |
| --- | --- | --- |
| XLMR Large | 0.6 B | Paper, Model Card |
| mT5 XL | 3.7 B | Paper, Model Card |
| Llama 3.1 | 8 B | Paper, Model Card |
| Gemma 2 | 9 B | Paper, Model Card |
Note
Llama 3.1 and Gemma 2 are loaded with 4-bit quantization and trained with LoRA adapters injected into all linear layers (Dettmers et al., 2023).
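For illustration, such a setup can be built with Hugging Face Transformers, bitsandbytes, and PEFT roughly as sketched below; the exact loading code in this repository may differ, and the checkpoint name is only an example:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# 4-bit NF4 quantization as in QLoRA (Dettmers et al., 2023)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b",             # example checkpoint
    quantization_config=bnb_config,
    device_map="auto",
)

# LoRA adapters injected into all linear layers (r=8, alpha=16, matching the defaults below)
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules="all-linear")
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()
```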
We provide an overview of the supported input parameters in `run_task_ft.py` for initial task adaptation to English task data:

- `task`: Specifies the task/dataset, e.g., "xnli", "nusax", or "tydiqa".
- `lang`: Sets the language (default: "en").
- `output_dir`: Directory to save checkpoints (default: "checkpoints").
- `path_data`: Data directory.
- `plm`: Identifies the pretrained language model to use; options include "gemma2-9b" and "llama3.1-8b".
- `adapter`: Adapter type, e.g., "lora".
- `lora_r` and `lora_alpha`: LoRA configuration parameters (r=8, alpha=16 by default).
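For example, task fine-tuning on English XNLI might be launched as follows. This is a sketch that assumes the parameters above are exposed as command-line flags; the data directory is a placeholder:

```bash
python run_task_ft.py \
    --task xnli \
    --lang en \
    --plm gemma2-9b \
    --adapter lora \
    --lora_r 8 \
    --lora_alpha 16 \
    --output_dir checkpoints \
    --path_data data
```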
We provide an overview of the supported input parameters in `run.py` for cross-lingual transfer of the task fine-tuned model with FLARE (also supports FLARE FuseMT, regular LoRA, and input-level fusion); example invocations follow the parameter descriptions below:

- `task`: Specifies the task/dataset, e.g., "xnli", "nusax", or "tydiqa".
- `source_lang` and `target_lang`: Set the source and target language, e.g., "en" and "es".
- `output_dir`: Directory to save checkpoints (default: "checkpoints").
- `path_data`: Data directory (default: "translations").
- `path_mt`: Path to the machine translated data (default: "translations").
- `load_ckpt`: Path to the checkpoint directory, e.g., "{output_dir}/{task}-{plm}-{source_lang}-{seed}".
- `plm`: Identifies the pretrained language model to use; options include "gemma2-9b" and "llama3.1-8b".
- `adapter`: Adapter type, e.g., "lora".
- `lora_r` and `lora_alpha`: LoRA configuration parameters (r=8, alpha=16 by default).
- `translate-test`: Sets the evaluation mode to "translate-test".
- `translate-train`: Sets the evaluation mode to "translate-train" if specified.
- `eval_zs`: Sets the evaluation mode to "zero-shot" if specified.
FLARE parameters:

- `fusion_fn`: Select a supported fusion function, including "add", "add_relu", "mul", or "cross-attention".
- `fuse_mt`: Sets the fusion mode to "FLARE MT", with MT encoder representations used as source language inputs.
- `mt_model`: Select a supported MT model, including "nllb-600m" and "nllb-3.3b".
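As an illustration, a FLARE transfer run from English to Spanish XNLI might look as follows. This is a sketch that assumes the parameters above are exposed as command-line flags; the checkpoint path (seed 42), data directories, and fusion function are placeholder choices:

```bash
python run.py \
    --task xnli \
    --source_lang en \
    --target_lang es \
    --plm gemma2-9b \
    --adapter lora \
    --lora_r 8 --lora_alpha 16 \
    --load_ckpt checkpoints/xnli-gemma2-9b-en-42 \
    --path_data translations \
    --path_mt translations \
    --fusion_fn add_relu \
    --translate-test
```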
Input-level Fusion:

- `input_fusion`: Sets the fusion mode to "input-level fusion", with source and target language inputs concatenated. Note: ensure `max_length` is adjusted accordingly.
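For input-level fusion, the same script might be invoked as below; again a sketch, where the flag spelling and the `max_length` value are illustrative assumptions:

```bash
python run.py \
    --task xnli \
    --source_lang en \
    --target_lang es \
    --plm gemma2-9b \
    --load_ckpt checkpoints/xnli-gemma2-9b-en-42 \
    --input_fusion \
    --max_length 512  # increase to fit the concatenated source and target inputs
```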