Upgrade to transformers==4.36.1; q_align==1.1.0
haoning.wu committed Jan 4, 2024
1 parent 89df28a commit d3bae04
Showing 12 changed files with 398 additions and 28 deletions.
13 changes: 11 additions & 2 deletions README.md
@@ -40,7 +40,6 @@
</div>




<h2>Results</h2>
<div style="width: 75%; text-align: center; margin:auto;">
@@ -60,6 +59,16 @@
</div>
</div>

## [Important Note!]


With this release, we have modified the relevant mPLUG-Owl2 code to work with the latest transformers version, i.e. `transformers==4.36.1`, so that you do not need to keep a separate, outdated environment when using Q-Align alongside other projects. The updated code is no longer compatible with older Q-Align versions (v1.0.1/v1.0.0 and before); please update to the newest version with the following commands:

```shell
git pull
pip install -e .
```
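
After pulling and reinstalling, a quick sanity check confirms that the pinned release is active (a minimal sketch; the expected version string comes from this commit's pin):

```python
# Post-update check: the installed transformers should match the pin in this commit.
import transformers

print(transformers.__version__)  # expected: 4.36.1
```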


## Installation

@@ -273,7 +282,7 @@ sh scripts/l1_lsvq.sh
- Training OneAlign with IQA datasets, AVA dataset (IAA) and LSVQ dataset (VQA):

```shell
-sh scripts/all_.sh
+sh scripts/onealign.sh
```

*At minimum, 8\*A6000 GPUs or 4\*A100 GPUs are sufficient for the training.*
3 changes: 3 additions & 0 deletions q_align/model/configuration_mplug_owl2.py
@@ -117,6 +117,7 @@ def __init__(
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
+attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
@@ -140,6 +141,8 @@ def __init__(
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
self.attention_bias = attention_bias
+self.attention_dropout = attention_dropout
+self._attn_implementation = "flash_attention_2"

super().__init__(
pad_token_id=pad_token_id,
340 changes: 334 additions & 6 deletions q_align/model/modeling_llama2.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions q_align/model/visual_encoder.py
@@ -383,6 +383,7 @@ def custom_forward(*inputs):

class MplugOwlVisionModel(PreTrainedModel):
main_input_name = "pixel_values"
_no_split_modules = ["MplugOwlVisionEncoderLayer"]

def __init__(self, config):
super().__init__(config)
@@ -754,6 +755,7 @@ def custom_forward(*inputs):


class MplugOwlVisualAbstractorModel(PreTrainedModel):
_no_split_modules = ["MplugOwlVisualAbstractorLayer"]
def __init__(self, config, language_hidden_size):
super().__init__(config)
self.config = config
17 changes: 5 additions & 12 deletions q_align/train/mplug_owl2_trainer.py
@@ -9,7 +9,6 @@
get_parameter_names,
has_length,
ALL_LAYERNORM_LAYERS,
-ShardedDDPOption,
logger,
)
from typing import List, Optional
@@ -154,10 +153,10 @@ def create_optimizer(self):
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
"""
-if is_sagemaker_mp_enabled():
-    return super().create_optimizer()
-if self.sharded_ddp == ShardedDDPOption.SIMPLE:
-    return super().create_optimizer()
+#if is_sagemaker_mp_enabled():
+#    return super().create_optimizer()
+#if self.sharded_ddp == ShardedDDPOption.SIMPLE:
+#    return super().create_optimizer()

opt_model = self.model

@@ -212,13 +211,7 @@ def create_optimizer(self):
ic(len(optimizer_grouped_parameters[0]['params']),len(optimizer_grouped_parameters[1]['params']))
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

-if self.sharded_ddp == ShardedDDPOption.SIMPLE:
-    self.optimizer = OSS(
-        params=optimizer_grouped_parameters,
-        optim=optimizer_cls,
-        **optimizer_kwargs,
-    )
-else:
+if True:
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
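
The commented-out branches above drop the fairscale sharded-DDP path (its `ShardedDDPOption` import is also deleted at the top of this file), while the retained docstring notes that a custom optimizer can instead be supplied through the `Trainer`'s `optimizers` argument rather than by overriding `create_optimizer`. A minimal sketch of that alternative, with a placeholder model and dataset rather than Q-Align's own:

```python
import torch
from torch import nn
from transformers import Trainer, TrainingArguments

# Placeholders so the wiring is runnable; Q-Align would use the mPLUG-Owl2 model and its SFT data.
model = nn.Linear(4, 2)
train_dataset = [{"input": torch.randn(4)}]

args = TrainingArguments(output_dir="./out", per_device_train_batch_size=1)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.0)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

# Supplying `optimizers=(optimizer, scheduler)` bypasses create_optimizer entirely.
trainer = Trainer(model=model, args=args, train_dataset=train_dataset,
                  optimizers=(optimizer, scheduler))
```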
4 changes: 2 additions & 2 deletions q_align/train/train.py
@@ -602,9 +602,9 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
continue
if self.data_args.image_aspect_ratio == 'pad':
image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
-image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+image = processor.preprocess(image, return_tensors='pt')['pixel_values']
else:
-image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+image = processor.preprocess(image, return_tensors='pt')['pixel_values']
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]),
self.data_args)
4 changes: 2 additions & 2 deletions q_align/train/train_mem.py
@@ -3,9 +3,9 @@
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.

# Need to call this before importing transformers.
-from q_align.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
+#from q_align.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

-replace_llama_attn_with_flash_attn()
+#replace_llama_attn_with_flash_attn()

from q_align.train.train import train

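
The FlashAttention monkey patch is disabled here because the configuration change above sets `self._attn_implementation = "flash_attention_2"` directly. For reference, stock transformers 4.36 exposes the same switch through `from_pretrained`; the checkpoint name below is a placeholder for illustration, not the Q-Align weights:

```python
import torch
from transformers import AutoModelForCausalLM

# Illustrative only: replace the placeholder checkpoint with a real one.
model = AutoModelForCausalLM.from_pretrained(
    "org/placeholder-llama-checkpoint",
    torch_dtype=torch.bfloat16,               # FlashAttention-2 needs fp16/bf16 weights
    attn_implementation="flash_attention_2",  # requires the flash-attn package
)
```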
2 changes: 1 addition & 1 deletion scripts/iqa_iaa.sh
@@ -2,7 +2,7 @@
LOAD='MAGAer13/mplug-owl2-llama2-7b'

DATA_FILE=playground/data/training_sft/train_iqa_iaa.json
-deepspeed --master_port 25801 mplug_owl2/train/train_mem.py \
+deepspeed --master_port 25801 q_align/train/train_mem.py \
--deepspeed ./scripts/zero3.json \
--model_name_or_path $LOAD \
--version v1 \
2 changes: 1 addition & 1 deletion scripts/iqa_mix.sh
@@ -2,7 +2,7 @@
LOAD='MAGAer13/mplug-owl2-llama2-7b'

DATA_FILE=playground/data/training_sft/train_koniq_spaq_kadid.json
-deepspeed --master_port 25801 mplug_owl2/train/train_mem.py \
+deepspeed --master_port 25801 q_align/train/train_mem.py \
--deepspeed ./scripts/zero3.json \
--model_name_or_path $LOAD \
--version v1 \
2 changes: 1 addition & 1 deletion scripts/iqa_vqa.sh
@@ -2,7 +2,7 @@
LOAD='MAGAer13/mplug-owl2-llama2-7b'

DATA_FILE=playground/data/training_sft/train_iqa_vqa.json
-deepspeed --master_port 25801 mplug_owl2/train/train_mem.py \
+deepspeed --master_port 25801 q_align/train/train_mem.py \
--deepspeed ./scripts/zero3.json \
--model_name_or_path $LOAD \
--version v1 \
2 changes: 1 addition & 1 deletion scripts/l1_koniq.sh
@@ -1,7 +1,7 @@
#!/bin/bash
LOAD='MAGAer13/mplug-owl2-llama2-7b'

-DATA_FILE=playground/data/training_sfttrain_koniq.json
+DATA_FILE=playground/data/training_sft/train_koniq.json
deepspeed --master_port 25801 q_align/train/train_mem.py \
--deepspeed ./scripts/zero3.json \
--model_name_or_path $LOAD \
35 changes: 35 additions & 0 deletions scripts/onealign.sh
@@ -0,0 +1,35 @@
#!/bin/bash
LOAD='MAGAer13/mplug-owl2-llama2-7b'

DATA_FILE=playground/data/training_sft/train_all.json
deepspeed --master_port 25801 q_align/train/train_mem.py \
--deepspeed ./scripts/zero3.json \
--model_name_or_path $LOAD \
--version v1 \
--data_path $DATA_FILE \
--image_folder playground/data/ \
--image_aspect_ratio pad \
--group_by_modality_length True \
--bf16 True \
--output_dir ./one-align \
--num_train_epochs 2 \
--per_device_train_batch_size 32 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 1100 \
--save_total_limit 2 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--tune_visual_abstractor True \
--freeze_vision_model False \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to wandb
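
As a rough consistency check against the README's hardware note (8\*A6000 or 4\*A100 GPUs), the effective batch size implied by this script is just the product of the per-device batch size, the accumulation steps, and the GPU count; a small illustrative calculation (the GPU counts are taken from the README note, not from this script):

```python
# Effective batch size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
per_device, accum = 32, 1
for num_gpus in (8, 4):  # 8*A6000 or 4*A100, per the README note
    print(num_gpus, "GPUs ->", per_device * accum * num_gpus, "samples per step")  # 256 and 128
```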
