diff --git a/README.md b/README.md index f847ee0..66a6e98 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,6 @@ -

Results

@@ -60,6 +59,16 @@
+## [Important Note!] + + +Thus, we have modified respective code for mPLUG-Owl2 to adapt it to the newest transformer version, i.e. `transformers==4.36.1`, so that you do not need to create a separate outdated environment while using it alongside other projects. The updated code is no longer compatible with the old-version Q-Align (v1.0.1/v1.0.0, and before), please update to the newest version via the following scripts: + +```shell +git pull +pip install -e . +``` + ## Installation @@ -273,7 +282,7 @@ sh scripts/l1_lsvq.sh - Training OneAlign with IQA datasets, AVA dataset (IAA) and LSVQ dataset (VQA): ```shell -sh scripts/all_.sh +sh scripts/onealign.sh ``` *At least 8\*A6000 GPUs or 4\*A100 GPUs will be enough for the training.* diff --git a/q_align/model/configuration_mplug_owl2.py b/q_align/model/configuration_mplug_owl2.py index e2e31a6..81362a6 100644 --- a/q_align/model/configuration_mplug_owl2.py +++ b/q_align/model/configuration_mplug_owl2.py @@ -117,6 +117,7 @@ def __init__( rope_theta=10000.0, rope_scaling=None, attention_bias=False, + attention_dropout=0.0, **kwargs, ): self.vocab_size = vocab_size @@ -140,6 +141,8 @@ def __init__( self.rope_scaling = rope_scaling self._rope_scaling_validation() self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self._attn_implementation = "flash_attention_2" super().__init__( pad_token_id=pad_token_id, diff --git a/q_align/model/modeling_llama2.py b/q_align/model/modeling_llama2.py index 3122bbb..e8c4f04 100644 --- a/q_align/model/modeling_llama2.py +++ b/q_align/model/modeling_llama2.py @@ -18,6 +18,7 @@ import transformers from transformers.models.llama.modeling_llama import * +from transformers.models.llama.modeling_llama import * from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging @@ -54,9 +55,18 @@ def forward(self, hidden_states, multiway_indices): class LlamaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: LlamaConfig): + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads @@ -64,6 +74,7 @@ def __init__(self, config: LlamaConfig): self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta + self.is_causal = True if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( @@ -182,14 +193,314 @@ def forward( attn_weights = None return attn_output, attn_weights, past_key_value + + +class LlamaFlashAttention2(LlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + modality_indicators: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # LlamaFlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states, modality_indicators) + value_states = self.v_proj(hidden_states, modality_indicators) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class LlamaSdpaAttention(LlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + modality_indicators: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + modality_indicators=modality_indicators, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states, modality_indicators) + value_states = self.v_proj(hidden_states, modality_indicators) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + +LLAMA_ATTENTION_CLASSES = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, +} class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig): + def __init__(self, config: LlamaConfig, layer_idx): super().__init__() self.hidden_size = config.hidden_size self.self_attn = LlamaAttention(config=config) + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) self.mlp = LlamaMLP(config) self.input_layernorm = MultiwayNetwork(module_provider=partial( LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps @@ -285,7 +596,7 @@ def model_forward( batch_size, seq_length, _ = inputs_embeds.shape else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - + seq_length_with_past = seq_length past_key_values_length = 0 @@ -309,9 +620,24 @@ def model_forward( attention_mask = torch.ones( (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._use_sdpa and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) hidden_states = inputs_embeds @@ -482,6 +808,8 @@ def causal_model_forward( def replace_llama_modality_adaptive(): transformers.models.llama.configuration_llama.LlamaConfig = LlamaConfig transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention + transformers.models.llama.modeling_llama.LlamaFlashAttention2 = LlamaFlashAttention2 + transformers.models.llama.modeling_llama.LlamaSdpaAttention = LlamaSdpaAttention transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer transformers.models.llama.modeling_llama.LlamaModel.forward = model_forward transformers.models.llama.modeling_llama.LlamaForCausalLM.forward = causal_model_forward diff --git a/q_align/model/visual_encoder.py b/q_align/model/visual_encoder.py index 4feaa7f..5bd3468 100644 --- a/q_align/model/visual_encoder.py +++ b/q_align/model/visual_encoder.py @@ -383,6 +383,7 @@ def custom_forward(*inputs): class MplugOwlVisionModel(PreTrainedModel): main_input_name = "pixel_values" + _no_split_modules = ["MplugOwlVisionEncoderLayer"] def __init__(self, config): super().__init__(config) @@ -754,6 +755,7 @@ def custom_forward(*inputs): class MplugOwlVisualAbstractorModel(PreTrainedModel): + _no_split_modules = ["MplugOwlVisualAbstractorLayer"] def __init__(self, config, language_hidden_size): super().__init__(config) self.config = config diff --git a/q_align/train/mplug_owl2_trainer.py b/q_align/train/mplug_owl2_trainer.py index 293dcdf..ce50d09 100644 --- a/q_align/train/mplug_owl2_trainer.py +++ b/q_align/train/mplug_owl2_trainer.py @@ -9,7 +9,6 @@ get_parameter_names, has_length, ALL_LAYERNORM_LAYERS, - ShardedDDPOption, logger, ) from typing import List, Optional @@ -154,10 +153,10 @@ def create_optimizer(self): We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the Trainer's init through `optimizers`, or subclass and override this method in a subclass. """ - if is_sagemaker_mp_enabled(): - return super().create_optimizer() - if self.sharded_ddp == ShardedDDPOption.SIMPLE: - return super().create_optimizer() + #if is_sagemaker_mp_enabled(): + # return super().create_optimizer() + #if self.sharded_ddp == ShardedDDPOption.SIMPLE: + # return super().create_optimizer() opt_model = self.model @@ -212,13 +211,7 @@ def create_optimizer(self): ic(len(optimizer_grouped_parameters[0]['params']),len(optimizer_grouped_parameters[1]['params'])) optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) - if self.sharded_ddp == ShardedDDPOption.SIMPLE: - self.optimizer = OSS( - params=optimizer_grouped_parameters, - optim=optimizer_cls, - **optimizer_kwargs, - ) - else: + if True: self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) if optimizer_cls.__name__ == "Adam8bit": import bitsandbytes diff --git a/q_align/train/train.py b/q_align/train/train.py index d109848..0bb55bf 100644 --- a/q_align/train/train.py +++ b/q_align/train/train.py @@ -602,9 +602,9 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]: continue if self.data_args.image_aspect_ratio == 'pad': image = expand2square(image, tuple(int(x*255) for x in processor.image_mean)) - image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + image = processor.preprocess(image, return_tensors='pt')['pixel_values'] else: - image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + image = processor.preprocess(image, return_tensors='pt')['pixel_values'] sources = preprocess_multimodal( copy.deepcopy([e["conversations"] for e in sources]), self.data_args) diff --git a/q_align/train/train_mem.py b/q_align/train/train_mem.py index 4698657..233184d 100644 --- a/q_align/train/train_mem.py +++ b/q_align/train/train_mem.py @@ -3,9 +3,9 @@ # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. # Need to call this before importing transformers. -from q_align.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn +#from q_align.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn -replace_llama_attn_with_flash_attn() +#replace_llama_attn_with_flash_attn() from q_align.train.train import train diff --git a/scripts/iqa_iaa.sh b/scripts/iqa_iaa.sh index f64a75b..5187d58 100644 --- a/scripts/iqa_iaa.sh +++ b/scripts/iqa_iaa.sh @@ -2,7 +2,7 @@ LOAD='MAGAer13/mplug-owl2-llama2-7b' DATA_FILE=playground/data/training_sft/train_iqa_iaa.json -deepspeed --master_port 25801 mplug_owl2/train/train_mem.py \ +deepspeed --master_port 25801 q_align/train/train_mem.py \ --deepspeed ./scripts/zero3.json \ --model_name_or_path $LOAD \ --version v1 \ diff --git a/scripts/iqa_mix.sh b/scripts/iqa_mix.sh index b155180..d159937 100644 --- a/scripts/iqa_mix.sh +++ b/scripts/iqa_mix.sh @@ -2,7 +2,7 @@ LOAD='MAGAer13/mplug-owl2-llama2-7b' DATA_FILE=playground/data/training_sft/train_koniq_spaq_kadid.json -deepspeed --master_port 25801 mplug_owl2/train/train_mem.py \ +deepspeed --master_port 25801 q_align/train/train_mem.py \ --deepspeed ./scripts/zero3.json \ --model_name_or_path $LOAD \ --version v1 \ diff --git a/scripts/iqa_vqa.sh b/scripts/iqa_vqa.sh index 02aa70c..075bdc2 100644 --- a/scripts/iqa_vqa.sh +++ b/scripts/iqa_vqa.sh @@ -2,7 +2,7 @@ LOAD='MAGAer13/mplug-owl2-llama2-7b' DATA_FILE=playground/data/training_sft/train_iqa_vqa.json -deepspeed --master_port 25801 mplug_owl2/train/train_mem.py \ +deepspeed --master_port 25801 q_align/train/train_mem.py \ --deepspeed ./scripts/zero3.json \ --model_name_or_path $LOAD \ --version v1 \ diff --git a/scripts/l1_koniq.sh b/scripts/l1_koniq.sh index 79bcde5..2522f23 100644 --- a/scripts/l1_koniq.sh +++ b/scripts/l1_koniq.sh @@ -1,7 +1,7 @@ #!/bin/bash LOAD='MAGAer13/mplug-owl2-llama2-7b' -DATA_FILE=playground/data/training_sfttrain_koniq.json +DATA_FILE=playground/data/training_sft/train_koniq.json deepspeed --master_port 25801 q_align/train/train_mem.py \ --deepspeed ./scripts/zero3.json \ --model_name_or_path $LOAD \ diff --git a/scripts/onealign.sh b/scripts/onealign.sh new file mode 100644 index 0000000..994ebb2 --- /dev/null +++ b/scripts/onealign.sh @@ -0,0 +1,35 @@ +#!/bin/bash +LOAD='MAGAer13/mplug-owl2-llama2-7b' + +DATA_FILE=playground/data/training_sft/train_all.json +deepspeed --master_port 25801 q_align/train/train_mem.py \ + --deepspeed ./scripts/zero3.json \ + --model_name_or_path $LOAD \ + --version v1 \ + --data_path $DATA_FILE \ + --image_folder playground/data/ \ + --image_aspect_ratio pad \ + --group_by_modality_length True \ + --bf16 True \ + --output_dir ./one-align \ + --num_train_epochs 2 \ + --per_device_train_batch_size 32 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 1100 \ + --save_total_limit 2 \ + --learning_rate 2e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 2048 \ + --gradient_checkpointing True \ + --tune_visual_abstractor True \ + --freeze_vision_model False \ + --dataloader_num_workers 4 \ + --lazy_preprocess True \ + --report_to wandb \ No newline at end of file