diff --git a/README.md b/README.md
index f847ee0..66a6e98 100644
--- a/README.md
+++ b/README.md
@@ -40,7 +40,6 @@
-
Results
@@ -60,6 +59,16 @@
+## [Important Note!]
+
+
+Thus, we have modified the mPLUG-Owl2 code to adapt it to the newest transformers version, i.e. `transformers==4.36.1`, so that you no longer need to create a separate, outdated environment when using Q-Align alongside other projects. The updated code is no longer compatible with older versions of Q-Align (v1.0.1/v1.0.0 and before); please update to the newest version with the following commands:
+
+```shell
+git pull
+pip install -e .
+```
+
## Installation
@@ -273,7 +282,7 @@ sh scripts/l1_lsvq.sh
- Training OneAlign with IQA datasets, AVA dataset (IAA) and LSVQ dataset (VQA):
```shell
-sh scripts/all_.sh
+sh scripts/onealign.sh
```
*At least 8\*A6000 GPUs or 4\*A100 GPUs will be enough for the training.*
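After pulling and reinstalling, a quick sanity check (a minimal sketch, not part of the repository) that the environment matches what the updated code expects:

```python
import transformers
from transformers.utils import is_flash_attn_2_available

# The updated mPLUG-Owl2 code targets transformers==4.36.1.
print("transformers:", transformers.__version__)

# The patched LlamaConfig defaults to attn_implementation="flash_attention_2",
# so flash-attn should be installed for training.
print("flash-attn available:", is_flash_attn_2_available())
```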
diff --git a/q_align/model/configuration_mplug_owl2.py b/q_align/model/configuration_mplug_owl2.py
index e2e31a6..81362a6 100644
--- a/q_align/model/configuration_mplug_owl2.py
+++ b/q_align/model/configuration_mplug_owl2.py
@@ -117,6 +117,7 @@ def __init__(
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
+ attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
@@ -140,6 +141,8 @@ def __init__(
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
self.attention_bias = attention_bias
+ self.attention_dropout = attention_dropout
+ self._attn_implementation = "flash_attention_2"
super().__init__(
pad_token_id=pad_token_id,
diff --git a/q_align/model/modeling_llama2.py b/q_align/model/modeling_llama2.py
index 3122bbb..e8c4f04 100644
--- a/q_align/model/modeling_llama2.py
+++ b/q_align/model/modeling_llama2.py
@@ -18,6 +18,7 @@
import transformers
from transformers.models.llama.modeling_llama import *
+from transformers.models.llama.modeling_llama import _get_unpad_data
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
@@ -54,9 +55,18 @@ def forward(self, hidden_states, multiway_indices):
class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
- def __init__(self, config: LlamaConfig):
+ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
@@ -64,6 +74,7 @@ def __init__(self, config: LlamaConfig):
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
+ self.is_causal = True
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
@@ -182,14 +193,314 @@ def forward(
attn_weights = None
return attn_output, attn_weights, past_key_value
+
+
+class LlamaFlashAttention2(LlamaAttention):
+ """
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stay
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ modality_indicators: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # LlamaFlashAttention2 attention does not support output_attentions
+ if "padding_mask" in kwargs:
+ warnings.warn(
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+ )
+
+ # overwrite attention_mask with padding_mask
+ attention_mask = kwargs.pop("padding_mask")
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states, modality_indicators)
+ value_states = self.v_proj(hidden_states, modality_indicators)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons,
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+ # in fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = self._flash_attention_forward(
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+ def _flash_attention_forward(
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+ ):
+ """
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
+ first unpads the input, then computes the attention scores and pads the final attention scores.
+
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+ position of padding tokens and 1 for the position of non-padding tokens.
+ dropout (`float`, *optional*):
+ Attention dropout probability
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
+ """
+ if not self._flash_attn_uses_top_left_mask:
+ causal = self.is_causal
+ else:
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+ causal = self.is_causal and query_length != 1
+
+ # Contains at least one padding token in the sequence
+ if attention_mask is not None:
+ batch_size = query_states.shape[0]
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+ query_states, key_states, value_states, attention_mask, query_length
+ )
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+ else:
+ attn_output = flash_attn_func(
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
+ )
+
+ return attn_output
+
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+ key_layer = index_first_axis(
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+ )
+ value_layer = index_first_axis(
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+ )
+ if query_length == kv_seq_len:
+ query_layer = index_first_axis(
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
+ )
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
+ ) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+
+class LlamaSdpaAttention(LlamaAttention):
+ """
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `LlamaAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to the
+ SDPA API.
+ """
+
+ # Adapted from LlamaAttention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ modality_indicators: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ modality_indicators=modality_indicators,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states, modality_indicators)
+ value_states = self.v_proj(hidden_states, modality_indicators)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and attention_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=attention_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+LLAMA_ATTENTION_CLASSES = {
+ "eager": LlamaAttention,
+ "flash_attention_2": LlamaFlashAttention2,
+ "sdpa": LlamaSdpaAttention,
+}
class LlamaDecoderLayer(nn.Module):
- def __init__(self, config: LlamaConfig):
+ def __init__(self, config: LlamaConfig, layer_idx):
super().__init__()
self.hidden_size = config.hidden_size
- self.self_attn = LlamaAttention(config=config)
+ self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.mlp = LlamaMLP(config)
self.input_layernorm = MultiwayNetwork(module_provider=partial(
LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps
@@ -285,7 +596,7 @@ def model_forward(
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
-
+
seq_length_with_past = seq_length
past_key_values_length = 0
@@ -309,9 +620,24 @@ def model_forward(
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
)
- attention_mask = self._prepare_decoder_attention_mask(
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
- )
+
+ if self._use_flash_attention_2:
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ elif self._use_sdpa and not output_attentions:
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ )
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+ )
hidden_states = inputs_embeds
@@ -482,6 +808,8 @@ def causal_model_forward(
def replace_llama_modality_adaptive():
transformers.models.llama.configuration_llama.LlamaConfig = LlamaConfig
transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
+ transformers.models.llama.modeling_llama.LlamaFlashAttention2 = LlamaFlashAttention2
+ transformers.models.llama.modeling_llama.LlamaSdpaAttention = LlamaSdpaAttention
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
transformers.models.llama.modeling_llama.LlamaModel.forward = model_forward
transformers.models.llama.modeling_llama.LlamaForCausalLM.forward = causal_model_forward
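To make the new dispatch concrete, here is a minimal sketch (module path taken from this patch; running it assumes the Q-Align package and its dependencies are installed) of how `replace_llama_modality_adaptive` and `LLAMA_ATTENTION_CLASSES` fit together:

```python
from q_align.model.modeling_llama2 import (
    LLAMA_ATTENTION_CLASSES,
    replace_llama_modality_adaptive,
)

# Patch transformers' Llama classes with the modality-adaptive variants defined above.
replace_llama_modality_adaptive()

# Each LlamaDecoderLayer now picks its attention module from this mapping, keyed by
# config._attn_implementation (set to "flash_attention_2" in the config patch above).
attn_cls = LLAMA_ATTENTION_CLASSES["flash_attention_2"]
print(attn_cls.__name__)  # LlamaFlashAttention2
```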
diff --git a/q_align/model/visual_encoder.py b/q_align/model/visual_encoder.py
index 4feaa7f..5bd3468 100644
--- a/q_align/model/visual_encoder.py
+++ b/q_align/model/visual_encoder.py
@@ -383,6 +383,7 @@ def custom_forward(*inputs):
class MplugOwlVisionModel(PreTrainedModel):
main_input_name = "pixel_values"
+ _no_split_modules = ["MplugOwlVisionEncoderLayer"]
def __init__(self, config):
super().__init__(config)
@@ -754,6 +755,7 @@ def custom_forward(*inputs):
class MplugOwlVisualAbstractorModel(PreTrainedModel):
+ _no_split_modules = ["MplugOwlVisualAbstractorLayer"]
def __init__(self, config, language_hidden_size):
super().__init__(config)
self.config = config
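`_no_split_modules` is what `accelerate` consults when sharding a model with `device_map="auto"`, so these entries keep each vision encoder layer and visual abstractor layer on a single device. A minimal sketch of multi-GPU loading that relies on this (the `q-future/one-align` checkpoint is the released OneAlign model; treat the exact call as illustrative):

```python
import torch
from transformers import AutoModelForCausalLM

# With _no_split_modules set, device_map="auto" will not split a vision encoder layer
# or a visual abstractor layer across two GPUs.
model = AutoModelForCausalLM.from_pretrained(
    "q-future/one-align",
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
```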
diff --git a/q_align/train/mplug_owl2_trainer.py b/q_align/train/mplug_owl2_trainer.py
index 293dcdf..ce50d09 100644
--- a/q_align/train/mplug_owl2_trainer.py
+++ b/q_align/train/mplug_owl2_trainer.py
@@ -9,7 +9,6 @@
get_parameter_names,
has_length,
ALL_LAYERNORM_LAYERS,
- ShardedDDPOption,
logger,
)
from typing import List, Optional
@@ -154,10 +153,10 @@ def create_optimizer(self):
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
"""
- if is_sagemaker_mp_enabled():
- return super().create_optimizer()
- if self.sharded_ddp == ShardedDDPOption.SIMPLE:
- return super().create_optimizer()
+ #if is_sagemaker_mp_enabled():
+ # return super().create_optimizer()
+ #if self.sharded_ddp == ShardedDDPOption.SIMPLE:
+ # return super().create_optimizer()
opt_model = self.model
@@ -212,13 +211,7 @@ def create_optimizer(self):
ic(len(optimizer_grouped_parameters[0]['params']),len(optimizer_grouped_parameters[1]['params']))
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
- if self.sharded_ddp == ShardedDDPOption.SIMPLE:
- self.optimizer = OSS(
- params=optimizer_grouped_parameters,
- optim=optimizer_cls,
- **optimizer_kwargs,
- )
- else:
+ if True:
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
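For context, a minimal sketch (standard `transformers` utilities, with a toy model standing in for `opt_model`) of the decay/no-decay grouping that `create_optimizer` builds before instantiating `optimizer_cls`:

```python
import torch
import torch.nn as nn
from transformers.trainer_pt_utils import get_parameter_names
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS

# Toy stand-in for opt_model; the real trainer applies the same grouping to the mPLUG-Owl2 model.
model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))

# Weight decay applies to everything except LayerNorm parameters and biases.
decay_parameters = [n for n in get_parameter_names(model, ALL_LAYERNORM_LAYERS) if "bias" not in n]
optimizer_grouped_parameters = [
    {"params": [p for n, p in model.named_parameters() if n in decay_parameters], "weight_decay": 0.01},
    {"params": [p for n, p in model.named_parameters() if n not in decay_parameters], "weight_decay": 0.0},
]

# With the ShardedDDP branch removed, the optimizer is now always built directly from these groups.
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=2e-5)
```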
diff --git a/q_align/train/train.py b/q_align/train/train.py
index d109848..0bb55bf 100644
--- a/q_align/train/train.py
+++ b/q_align/train/train.py
@@ -602,9 +602,9 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
continue
if self.data_args.image_aspect_ratio == 'pad':
image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
- image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values']
else:
- image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values']
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]),
self.data_args)
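The change above keeps the leading batch dimension returned by the image processor instead of indexing it away. A minimal sketch of the shape difference (the 448x448 input size for the mPLUG-Owl2 vision encoder is an assumption; the tensor is synthetic):

```python
import torch

# Stand-in for processor.preprocess(image, return_tensors='pt')['pixel_values']: a batched tensor.
pixel_values = torch.rand(1, 3, 448, 448)

old_style = pixel_values[0]  # shape (3, 448, 448)    -- what the previous `[0]` indexing produced
new_style = pixel_values     # shape (1, 3, 448, 448) -- what the updated dataset now returns
print(old_style.shape, new_style.shape)
```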
diff --git a/q_align/train/train_mem.py b/q_align/train/train_mem.py
index 4698657..233184d 100644
--- a/q_align/train/train_mem.py
+++ b/q_align/train/train_mem.py
@@ -3,9 +3,9 @@
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
# Need to call this before importing transformers.
-from q_align.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
+#from q_align.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
-replace_llama_attn_with_flash_attn()
+#replace_llama_attn_with_flash_attn()
from q_align.train.train import train
diff --git a/scripts/iqa_iaa.sh b/scripts/iqa_iaa.sh
index f64a75b..5187d58 100644
--- a/scripts/iqa_iaa.sh
+++ b/scripts/iqa_iaa.sh
@@ -2,7 +2,7 @@
LOAD='MAGAer13/mplug-owl2-llama2-7b'
DATA_FILE=playground/data/training_sft/train_iqa_iaa.json
-deepspeed --master_port 25801 mplug_owl2/train/train_mem.py \
+deepspeed --master_port 25801 q_align/train/train_mem.py \
--deepspeed ./scripts/zero3.json \
--model_name_or_path $LOAD \
--version v1 \
diff --git a/scripts/iqa_mix.sh b/scripts/iqa_mix.sh
index b155180..d159937 100644
--- a/scripts/iqa_mix.sh
+++ b/scripts/iqa_mix.sh
@@ -2,7 +2,7 @@
LOAD='MAGAer13/mplug-owl2-llama2-7b'
DATA_FILE=playground/data/training_sft/train_koniq_spaq_kadid.json
-deepspeed --master_port 25801 mplug_owl2/train/train_mem.py \
+deepspeed --master_port 25801 q_align/train/train_mem.py \
--deepspeed ./scripts/zero3.json \
--model_name_or_path $LOAD \
--version v1 \
diff --git a/scripts/iqa_vqa.sh b/scripts/iqa_vqa.sh
index 02aa70c..075bdc2 100644
--- a/scripts/iqa_vqa.sh
+++ b/scripts/iqa_vqa.sh
@@ -2,7 +2,7 @@
LOAD='MAGAer13/mplug-owl2-llama2-7b'
DATA_FILE=playground/data/training_sft/train_iqa_vqa.json
-deepspeed --master_port 25801 mplug_owl2/train/train_mem.py \
+deepspeed --master_port 25801 q_align/train/train_mem.py \
--deepspeed ./scripts/zero3.json \
--model_name_or_path $LOAD \
--version v1 \
diff --git a/scripts/l1_koniq.sh b/scripts/l1_koniq.sh
index 79bcde5..2522f23 100644
--- a/scripts/l1_koniq.sh
+++ b/scripts/l1_koniq.sh
@@ -1,7 +1,7 @@
#!/bin/bash
LOAD='MAGAer13/mplug-owl2-llama2-7b'
-DATA_FILE=playground/data/training_sfttrain_koniq.json
+DATA_FILE=playground/data/training_sft/train_koniq.json
deepspeed --master_port 25801 q_align/train/train_mem.py \
--deepspeed ./scripts/zero3.json \
--model_name_or_path $LOAD \
diff --git a/scripts/onealign.sh b/scripts/onealign.sh
new file mode 100644
index 0000000..994ebb2
--- /dev/null
+++ b/scripts/onealign.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+LOAD='MAGAer13/mplug-owl2-llama2-7b'
+
+DATA_FILE=playground/data/training_sft/train_all.json
+deepspeed --master_port 25801 q_align/train/train_mem.py \
+ --deepspeed ./scripts/zero3.json \
+ --model_name_or_path $LOAD \
+ --version v1 \
+ --data_path $DATA_FILE \
+ --image_folder playground/data/ \
+ --image_aspect_ratio pad \
+ --group_by_modality_length True \
+ --bf16 True \
+ --output_dir ./one-align \
+ --num_train_epochs 2 \
+ --per_device_train_batch_size 32 \
+ --per_device_eval_batch_size 4 \
+ --gradient_accumulation_steps 1 \
+ --evaluation_strategy "no" \
+ --save_strategy "steps" \
+ --save_steps 1100 \
+ --save_total_limit 2 \
+ --learning_rate 2e-5 \
+ --weight_decay 0. \
+ --warmup_ratio 0.03 \
+ --lr_scheduler_type "cosine" \
+ --logging_steps 1 \
+ --tf32 True \
+ --model_max_length 2048 \
+ --gradient_checkpointing True \
+ --tune_visual_abstractor True \
+ --freeze_vision_model False \
+ --dataloader_num_workers 4 \
+ --lazy_preprocess True \
+ --report_to wandb
\ No newline at end of file