From d753cb8e384b56c4d6b450fd0d7a68e0925c600f Mon Sep 17 00:00:00 2001
From: lkk <33276950+lkk12014402@users.noreply.github.com>
Date: Mon, 15 Jan 2024 22:04:57 +0800
Subject: [PATCH] [NeuralChat] support llama series model for llava finetuning.
 (#948)

---
 .../examples/finetuning/multi_modal/README.md |   8 ++
 .../examples/finetuning/multi_modal/train.py  |  59 ++++++---
 .../modeling/llava_models/__init__.py         |   1 -
 .../modeling/llava_models/llava_llama.py      | 113 ++++++++++++++++++
 4 files changed, 164 insertions(+), 17 deletions(-)
 create mode 100644 intel_extension_for_transformers/transformers/modeling/llava_models/llava_llama.py

diff --git a/intel_extension_for_transformers/neural_chat/examples/finetuning/multi_modal/README.md b/intel_extension_for_transformers/neural_chat/examples/finetuning/multi_modal/README.md
index 0ce7e3da1f9..5a26ebadc5c 100644
--- a/intel_extension_for_transformers/neural_chat/examples/finetuning/multi_modal/README.md
+++ b/intel_extension_for_transformers/neural_chat/examples/finetuning/multi_modal/README.md
@@ -3,6 +3,14 @@
 Large Language and Vision Assistant (LLaVA) is a multi-modal training framework proposed in [Visual Instruction Tuning](https://arxiv.org/abs/2304.08485) and [Improved Baselines with Visual Instruction Tuning](https://arxiv.org/abs/2310.03744). This example demonstrates how to train a multi-modal model on Intel Gaudi2.
 
+## Validated Model List
+|Pretrained model| LLaVA |
+|------------------------------------|---|
+|Mistral series| ✅|
+|LLaMA series| ✅|
+
+**Note:** Salesforce/codegen25-7b-* models use the LLaMA architecture; for them, run `pip install transformers==4.33.2` (see [this issue](https://github.com/salesforce/CodeGen/issues/82)).
+
 ## Train
 LLaVA training consists of two stages: (1) feature alignment stage: use our 558K subset of the LAION-CC-SBU dataset to connect a *frozen pretrained* vision encoder to a *frozen LLM*; (2) visual instruction tuning stage: use 150K GPT-generated multimodal instruction-following data, plus around 515K VQA data from academic-oriented tasks, to teach the model to follow multimodal instructions.
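The validated list above is enforced by an architecture check in train.py. The following is a minimal stand-alone sketch of that check; the helper name, the SUPPORTED mapping, and the example checkpoint are illustrative and not part of this patch.

# Minimal sketch (not part of this patch): preview which validated family a
# checkpoint falls into before launching training. Requires only `transformers`;
# the helper name and the example checkpoint below are illustrative.
from transformers import AutoConfig

SUPPORTED = {
    "LlamaForCausalLM": "LLaMA series",
    "MistralForCausalLM": "Mistral series",
}

def validated_family(model_name_or_path: str) -> str:
    config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
    arch = config.architectures[0] if config.architectures else "unknown"
    if arch not in SUPPORTED:
        raise ValueError(f"{arch} is not in the validated model list for LLaVA finetuning")
    return SUPPORTED[arch]

print(validated_family("meta-llama/Llama-2-7b-hf"))  # example checkpoint name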
diff --git a/intel_extension_for_transformers/neural_chat/examples/finetuning/multi_modal/train.py b/intel_extension_for_transformers/neural_chat/examples/finetuning/multi_modal/train.py
index cde09935a6d..eb8f3fc58af 100644
--- a/intel_extension_for_transformers/neural_chat/examples/finetuning/multi_modal/train.py
+++ b/intel_extension_for_transformers/neural_chat/examples/finetuning/multi_modal/train.py
@@ -25,9 +25,8 @@
 import transformers
-from transformers import AutoTokenizer, set_seed, BitsAndBytesConfig
+from transformers import AutoTokenizer, set_seed, BitsAndBytesConfig, AutoConfig
 from transformers.integrations.deepspeed import is_deepspeed_available
-from intel_extension_for_transformers.transformers.modeling.llava_models import LlavaMistralForCausalLM
 from llava_utils import *
 
 if is_hpu_available:
@@ -133,19 +132,46 @@ def train():
         low_cpu_mem_usage = False
         device_map = None
-
-    model = LlavaMistralForCausalLM.from_pretrained(
-        model_args.model_name_or_path,
-        cache_dir=training_args.cache_dir,
-        load_in_4bit=training_args.bits == 4,
-        load_in_8bit=training_args.bits == 8,
-        low_cpu_mem_usage=low_cpu_mem_usage,
-        device_map=device_map,
-        quantization_config=quantization_config,
-        torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)),
-        trust_remote_code=model_args.trust_remote_code,
-        use_auth_token=model_args.use_auth_token
-    )
+    config_kwargs = {
+        "cache_dir": training_args.cache_dir,
+        "trust_remote_code": model_args.trust_remote_code,
+    }
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
+
+    use_fast = True
+    if config.architectures[0] == "LlamaForCausalLM":
+        from intel_extension_for_transformers.transformers.modeling.llava_models.llava_llama \
+            import LlavaLlamaForCausalLM
+        model = LlavaLlamaForCausalLM.from_pretrained(
+            model_args.model_name_or_path,
+            cache_dir=training_args.cache_dir,
+            load_in_4bit=training_args.bits == 4,
+            load_in_8bit=training_args.bits == 8,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            device_map=device_map,
+            quantization_config=quantization_config,
+            torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)),
+            trust_remote_code=model_args.trust_remote_code,
+            use_auth_token=model_args.use_auth_token
+        )
+        use_fast = False
+    elif config.architectures[0] == "MistralForCausalLM":
+        from intel_extension_for_transformers.transformers.modeling.llava_models.llava_mistral \
+            import LlavaMistralForCausalLM
+        model = LlavaMistralForCausalLM.from_pretrained(
+            model_args.model_name_or_path,
+            cache_dir=training_args.cache_dir,
+            load_in_4bit=training_args.bits == 4,
+            load_in_8bit=training_args.bits == 8,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            device_map=device_map,
+            quantization_config=quantization_config,
+            torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)),
+            trust_remote_code=model_args.trust_remote_code,
+            use_auth_token=model_args.use_auth_token
+        )
+    else:
+        raise ValueError("No llava implementation for the model {}".format(model_args.model_name_or_path))
 
     # for training
     model.config.use_cache = False
@@ -189,7 +215,8 @@ def make_inputs_require_grad(module, input, output):
         cache_dir=training_args.cache_dir,
         model_max_length=training_args.model_max_length,
         padding_side="right",
-        # use_fast=False
+        trust_remote_code=model_args.trust_remote_code,
+        use_fast=use_fast
    )
    tokenizer.pad_token = tokenizer.eos_token
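The two branches above differ only in which LLaVA wrapper is imported and whether the tokenizer uses the fast implementation. As a rough summary of that dispatch, here is an illustrative registry-style sketch; the dict and helper are not code from this patch, though the class and module names are the ones it adds.

# Illustrative summary (not part of this patch) of the dispatch in train():
# map the backbone architecture reported by AutoConfig to the matching LLaVA
# wrapper class and the tokenizer's use_fast setting.
from transformers import AutoConfig
from intel_extension_for_transformers.transformers.modeling.llava_models.llava_llama import LlavaLlamaForCausalLM
from intel_extension_for_transformers.transformers.modeling.llava_models.llava_mistral import LlavaMistralForCausalLM

LLAVA_WRAPPERS = {
    # architecture          -> (LLaVA model class,     fast tokenizer?)
    "LlamaForCausalLM":   (LlavaLlamaForCausalLM,   False),
    "MistralForCausalLM": (LlavaMistralForCausalLM, True),
}

def resolve_llava_wrapper(model_name_or_path, cache_dir=None, trust_remote_code=False):
    config = AutoConfig.from_pretrained(
        model_name_or_path, cache_dir=cache_dir, trust_remote_code=trust_remote_code
    )
    arch = config.architectures[0]
    if arch not in LLAVA_WRAPPERS:
        raise ValueError("No llava implementation for the model {}".format(model_name_or_path))
    return LLAVA_WRAPPERS[arch]  # (model_cls, use_fast)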
diff --git a/intel_extension_for_transformers/transformers/modeling/llava_models/__init__.py b/intel_extension_for_transformers/transformers/modeling/llava_models/__init__.py
index 565876c2d41..ed04d17bdbe 100644
--- a/intel_extension_for_transformers/transformers/modeling/llava_models/__init__.py
+++ b/intel_extension_for_transformers/transformers/modeling/llava_models/__init__.py
@@ -15,4 +15,3 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .llava_mistral import LlavaMistralForCausalLM
diff --git a/intel_extension_for_transformers/transformers/modeling/llava_models/llava_llama.py b/intel_extension_for_transformers/transformers/modeling/llava_models/llava_llama.py
new file mode 100644
index 00000000000..d720b2db5a5
--- /dev/null
+++ b/intel_extension_for_transformers/transformers/modeling/llava_models/llava_llama.py
@@ -0,0 +1,113 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2022 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from transformers import AutoConfig, AutoModelForCausalLM, \
+    LlamaConfig, LlamaModel, LlamaForCausalLM
+
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from .llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
+
+
+class LlavaConfig(LlamaConfig):
+    model_type = "llava"
+
+
+class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
+    config_class = LlavaConfig
+
+    def __init__(self, config: LlamaConfig):
+        super(LlavaLlamaModel, self).__init__(config)
+
+
+class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
+    config_class = LlavaConfig
+
+    def __init__(self, config):
+        super(LlavaLlamaForCausalLM, self).__init__(config)
+        self.model = LlavaLlamaModel(config)
+        self.pretraining_tp = config.pretraining_tp
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_model(self):
+        return self.model
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        images: Optional[torch.FloatTensor] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+        if inputs_embeds is None:
+            (
+                input_ids,
+                position_ids,
+                attention_mask,
+                past_key_values,
+                inputs_embeds,
+                labels
+            ) = self.prepare_inputs_labels_for_multimodal(
+                input_ids,
+                position_ids,
+                attention_mask,
+                past_key_values,
+                labels,
+                images
+            )
+
+        # pylint: disable=E1101
+        return super().forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            labels=labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict
+        )
+
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+        images = kwargs.pop("images", None)
+        # pylint: disable=E1101
+        _inputs = super().prepare_inputs_for_generation(
+            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
+        )
+        if images is not None:
+            _inputs['images'] = images
+        return _inputs
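For context, a rough usage sketch of the new class follows. The checkpoint path and image-tensor shape are assumptions, and the vision tower/projector setup handled by LlavaMetaForCausalLM in llava_arch is not shown in this patch.

# Rough usage sketch only (not part of this patch). Assumes a finetuned
# LLaVA-LLaMA checkpoint exists at CHECKPOINT; the dummy image tensor stands in
# for properly preprocessed pixel values.
import torch
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers.modeling.llava_models.llava_llama import LlavaLlamaForCausalLM

CHECKPOINT = "path/to/finetuned-llava-llama"  # placeholder path
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT, use_fast=False)
model = LlavaLlamaForCausalLM.from_pretrained(CHECKPOINT, torch_dtype=torch.bfloat16)

prompt_ids = tokenizer("Describe the image.", return_tensors="pt").input_ids
images = torch.zeros(1, 3, 336, 336, dtype=torch.bfloat16)  # shape is an assumption (CLIP ViT-L/14-336)

# generate() passes `images` to prepare_inputs_for_generation(), which threads it
# into forward(); there prepare_inputs_labels_for_multimodal() merges the image
# features into inputs_embeds before the LLaMA backbone runs.
output_ids = model.generate(prompt_ids, images=images, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))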