diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 0ddb0acc9..54d9bcb81 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -230,6 +230,7 @@ def cloud_ai_100_exec_kv( stream: bool = True, write_io_dir: Optional[str] = None, automation=False, + prompt_to_lora_id_mapping: Optional[List[int]] = None, ): """ This method generates output until ``eos`` or ``generation_len`` by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards. @@ -277,6 +278,7 @@ def cloud_ai_100_exec_kv( stream=stream, write_io_dir=write_io_dir, full_batch_size=full_batch_size, + prompt_to_lora_id_mapping=prompt_to_lora_id_mapping, ) if full_batch_size is None: exec_info = [ @@ -313,6 +315,7 @@ def __init__( qpc_path: str, prompt: List[str], full_batch_size: Optional[int] = None, + prompt_to_lora_id_mapping: Optional[List[int]] = None, ctx_len: Optional[int] = None, generation_len: Optional[int] = None, device_id: Optional[List[int]] = None, @@ -342,6 +345,16 @@ def __init__( full_batch_size if full_batch_size else self._fetch_full_batch_size() ) # Check and fetch full batch size if CB is enabled + if prompt_to_lora_id_mapping: + self.prompt_to_lora_id_mapping_prefill = deque(prompt_to_lora_id_mapping) + if self.full_batch_size: + self.prompt_to_lora_id_mapping_decode = prompt_to_lora_id_mapping + else: + self.prompt_to_lora_id_mapping_decode = deque(prompt_to_lora_id_mapping) + else: + self.prompt_to_lora_id_mapping_prefill = None + self.prompt_to_lora_id_mapping_decode = None + self.set_tokenizer_params() # set tokenizer params # Initialize the storage variables. @@ -461,6 +474,16 @@ def prepare_decode_inputs(self): if self.batch_index is not None: decode_inputs["batch_index"] = self.batch_index + if self.prompt_to_lora_id_mapping_decode: + if self.full_batch_size: + first_batch_lora_ids = [self.prompt_to_lora_id_mapping_decode[i] for i in range(self.full_batch_size)] + decode_inputs["lora_ids"] = np.array(first_batch_lora_ids, dtype=np.int64).reshape( + self.full_batch_size, 1 + ) + else: + batch_lora_ids = [self.prompt_to_lora_id_mapping_decode.popleft() for i in range(self.batch_size)] + decode_inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1) + return decode_inputs def _update_decode_input(self, outputs, position_ids, generation_len, decode_batch_id=None): @@ -549,6 +572,15 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i if decode_batch_id is not None: inputs["batch_index"] = decode_batch_id + if self.prompt_to_lora_id_mapping_prefill: + if self.full_batch_size: + inputs["lora_ids"] = np.array(self.prompt_to_lora_id_mapping_prefill.popleft(), dtype=np.int64).reshape( + 1, 1 + ) + else: + batch_lora_ids = [self.prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)] + inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1) + for i in range(num_chunks): chunk_inputs = inputs.copy() chunk_inputs["input_ids"] = inputs["input_ids"][ @@ -625,6 +657,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): self.session.set_buffers({"logits": logits_out_placeholder}) decode_pause_time += perf_counter() - start + + if self.prompt_to_lora_id_mapping_decode: + decode_inputs["lora_ids"][decode_batch_id] = self.prompt_to_lora_id_mapping_decode[ + batch_id_map[decode_batch_id] + ] + else: current_decode_ongoing[decode_batch_id] = 
False else: @@ -636,6 +674,7 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): ) generated_id_current_index[decode_batch_id] += 1 + return decode_pause_time def run_decode(self, decode_inputs, generation_len): diff --git a/QEfficient/peft/auto.py b/QEfficient/peft/auto.py index 85a66c527..bcdd79bdf 100644 --- a/QEfficient/peft/auto.py +++ b/QEfficient/peft/auto.py @@ -12,7 +12,7 @@ import numpy as np import torch -from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM, load_peft_weights +from peft import AutoPeftModelForCausalLM, PeftConfig, PeftModelForCausalLM, load_peft_weights from torch import nn from transformers import GenerationConfig, StoppingCriteria, StoppingCriteriaList from transformers.generation.streamers import BaseStreamer @@ -21,6 +21,7 @@ from QEfficient.base.onnx_transforms import FP16ClipTransform, OnnxTransform, SplitTensorsTransform from QEfficient.base.pytorch_transforms import PytorchTransform from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM from QEfficient.peft.onnx_transforms import AdapterWeightsToInputsTransform from QEfficient.peft.pytorch_transforms import PeftModelInputsTransform from QEfficient.transformers.pytorch_transforms import CustomOpsTransform, KVCacheTransform @@ -38,6 +39,7 @@ class QEffAutoPeftModelForCausalLM(QEFFBaseModel): Args: :model (nn.Module): PyTorch model + :finite_adapters (bool): set True to enable finite adapter mode with QEffAutoLoraModelForCausalLM class. Please refer to QEffAutoLoraModelForCausalLM for API specification. .. code-block:: python @@ -80,6 +82,9 @@ def __init__(self, model: nn.Module): for adapter_name in model.peft_config } + def __repr__(self) -> str: + return self.__class__.__name__ + "\n" + self.model.__repr__() + @property def model_name(self) -> str: mname = self.model.get_base_model().__class__.__name__ + "-lora" @@ -145,6 +150,8 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs): """ Args: :pretrained_name_or_path (str): Model card name from huggingface or local path to model directory. + :finite_adapters (bool): set True to enable finite adapter mode with QEffAutoLoraModelForCausalLM class. Please refer to QEffAutoLoraModelForCausalLM for API specification. + :adapter_name (str): Name used to identify loaded adapter. :args, kwargs: Additional arguments to pass to peft.AutoPeftModelForCausalLM. 
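Note on the runtime plumbing above: the `prompt_to_lora_id_mapping` list is consumed differently at prefill and decode time. A minimal NumPy-only sketch of the resulting `lora_ids` buffer shapes, using assumed example values (the mapping, batch size, and id assignments are illustrative, not library defaults):

```python
import numpy as np

# Illustrative values; id 0 is reserved for the base model (no adapter).
full_batch_size = 4
prompt_to_lora_id_mapping = [1, 2, 1, 0]

# Prefill handles one prompt per call in continuous-batching mode -> shape (1, 1).
prefill_lora_ids = np.array(prompt_to_lora_id_mapping[0], dtype=np.int64).reshape(1, 1)

# Decode carries one adapter id per batch slot -> shape (full_batch_size, 1).
decode_lora_ids = np.array(prompt_to_lora_id_mapping[:full_batch_size], dtype=np.int64).reshape(full_batch_size, 1)

print(prefill_lora_ids.shape, decode_lora_ids.shape)  # (1, 1) (4, 1)
```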
""" if kwargs.get("full_batch_size"): @@ -152,7 +159,22 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs): if kwargs.get("use_cache") is False: warnings.warn("Overriding to use_cache=True") kwargs["use_cache"] = True - obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs) + + if kwargs.pop("finite_adapters", False): # initialize through finite_adapters class + obj = QEffAutoLoraModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=PeftConfig.from_pretrained( + pretrained_name_or_path + ).base_model_name_or_path, + **kwargs, + ) + if adapter_name := kwargs.pop("adapter_name", None): + obj.load_adapter(pretrained_name_or_path, adapter_name=adapter_name) + return obj + if len(args) == 0 or not isinstance(list(args)[0], str): + raise TypeError("Required adapter name argument in string format") + obj.load_adapter(pretrained_name_or_path, list(args)[0]) + else: + obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs) return obj def export(self, export_dir: Optional[str] = None) -> str: diff --git a/QEfficient/peft/lora/__init__.py b/QEfficient/peft/lora/__init__.py new file mode 100644 index 000000000..361972ba7 --- /dev/null +++ b/QEfficient/peft/lora/__init__.py @@ -0,0 +1,12 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +from QEfficient.peft.lora.auto import QEffAutoLoraModelForCausalLM + +__all__ = [ + "QEffAutoLoraModelForCausalLM", +] diff --git a/QEfficient/peft/lora/auto.py b/QEfficient/peft/lora/auto.py new file mode 100644 index 000000000..2ccfac12a --- /dev/null +++ b/QEfficient/peft/lora/auto.py @@ -0,0 +1,387 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import hashlib +from pathlib import Path +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from peft import PeftConfig, load_peft_weights +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + +import QEfficient +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.peft.lora.pytorch_transforms import LoraModelInputsTransform, TargetModulesTransform +from QEfficient.utils import constants, get_padding_shape_from_config +from QEfficient.utils.cache import to_hashable +from QEfficient.utils.logging_utils import logger + + +class QEffAutoLoraModelForCausalLM(QEFFAutoModelForCausalLM): + """ + QEff class for loading models with multiple LoRA adapters. Currently only Mistral and Llama model are supported. + Once exported and compiled, the qpc can perform mixed batch inference with provided `prompt_to_adapter_mapping`. + + Args: + :model (nn.Module): PyTorch model + :continuous_batching (bool): Weather this model will be used for continuous batching in future. If this is not set True here, the model can not be exported/compiled for continuous batching later. + + .. 
code-block:: python + + from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM + + m = QEffAutoLoraModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") + m.load_adapter("predibase/gsm8k", "gsm8k") + m.load_adapter("predibase/magicoder", "magicoder") + m.compile(num_cores=16, device_group=[0]) + + prompts=["code prompt", "math prompt", "generic"] + m.generate(prompts, device_group=[0], prompt_to_adapter_mapping=["magicoder","gsm8k","base"]) + + """ + + def __init__(self, model: nn.Module, continuous_batching: bool = False, **kwargs) -> None: + super().__init__(model, continuous_batching) + if self.model.__class__.__name__ not in ["QEffMistralForCausalLM", "QEffLlamaForCausalLM"]: + raise NotImplementedError( + f"Only QEffMistralForCausalLM and QEffLlamaForCausalLM models are supported but got {self.model.__class__.__name__}" + ) + + self.adapter_weights = {} + self.adapter_configs = {} + self.active_adapter_to_id = {} + + self.lora_rank = 0 + self.target_modules_for_all_adapters = [] + + def __repr__(self) -> str: + return self.__class__.__name__ + "\n" + self.model.__repr__() + + @property + def model_hash(self) -> str: + mhash = hashlib.sha256() + + # should use model config here + mhash.update(to_hashable(self.model.model.config.to_diff_dict())) + + # create active adapter config dict + active_adapter_configs = {} + for adpt in self.active_adapter_to_id.keys(): + active_adapter_configs[adpt] = self.adapter_configs[adpt].to_dict() + mhash.update(to_hashable(active_adapter_configs)) + + # create active adapter weight dict + active_adapter_weights = {} + for adpt in self.active_adapter_to_id.keys(): + active_adapter_weights[adpt] = {key: value.tolist() for key, value in self.adapter_weights[adpt].items()} + mhash.update(to_hashable(active_adapter_weights)) + + # ensure model will be exported again if order of adapters changes + mhash.update(to_hashable(self.active_adapter_to_id)) + + # noncb & cb should have different onnx & qpc + mhash.update(to_hashable({"continuous_batching": self.continuous_batching})) + + mhash = mhash.hexdigest()[:16] + return mhash + + def download_adapter( + self, + adapter_model_id: str, + adapter_name: str, + adapter_weight: Optional[dict] = None, + adapter_config: Optional[PeftConfig] = None, + ): + """ + Loads a new adapter from huggingface hub or local path into CPU cache + + ``Mandatory`` Args: + :adapter_model_id (str): Adapter model ID from huggingface hub or local path + :adapter_name (str): Adapter name to be used to download this adapter + ``Optional`` Args: + :adapter_weight (dict): Adapter weight tensors in dictionary format + :adapter_config (PeftConfig): Adapter config in the format of PeftConfig + """ + + # check if adapter name is already loaded + if (adapter_name in self.adapter_weights.keys()) and (adapter_name in self.adapter_configs.keys()): + logger.warning(f"{adapter_name} has been loaded.
Skip download.") + else: + if adapter_weight and adapter_config: # if sufficiently get adapter weight and adpater config + self.adapter_weights[adapter_name] = adapter_weight + self.adapter_configs[adapter_name] = adapter_config + else: # donwload with adapter_model_id + self.adapter_weights[adapter_name] = { + k: v.numpy().astype("float16") for k, v in load_peft_weights(adapter_model_id).items() + } + self.adapter_configs[adapter_name] = PeftConfig.from_pretrained(adapter_model_id) + + def load_adapter( + self, + adapter_model_id: str, + adapter_name: str, + adapter_weight: Optional[dict] = None, + adapter_config: Optional[PeftConfig] = None, + ): + """ + Load adapter into CPU cache and set it as active + + ``Mandatory`` Args: + :adapter_model_id (str): Adapter model ID from huggingface hub or local path + :adapter_name (str): Adapter name to be used to load this adapter + ``Optional`` Args: + :adapter_weight (dict): Adapter weight tensors in dictionary format + :adapter_config (PeftConfig): Adapter config in the format of PeftConfig + """ + + # check if adapter name already exist and activated + if adapter_name in self.active_adapter_to_id.keys(): + logger.warning(f"{adapter_name} exists and activated. Please provide a different adapter_name.") + else: + self.download_adapter(adapter_model_id, adapter_name, adapter_weight, adapter_config) + + # starting from the second adapter_name, check if adapters has same target module and rank + if list(self.adapter_configs.values())[0] and ( + self.adapter_configs[adapter_name].target_modules + != list(self.adapter_configs.values())[0].target_modules + ): + raise ValueError( + f"{adapter_name} must have same target_modules as {list(self.adapter_configs.keys())[0]}" + ) + if list(self.adapter_configs.values())[0] and ( + self.adapter_configs[adapter_name].r != list(self.adapter_configs.values())[0].r + ): + raise ValueError(f"{adapter_name} must have same rank as {list(self.adapter_configs.keys())[0]}") + + # set active adapter id to current max if adapter_name is new + if adapter_name not in self.active_adapter_to_id.keys(): + self.active_adapter_to_id[adapter_name] = len(self.active_adapter_to_id) + 1 # reserve 0 for base + + return self.active_adapter_to_id[adapter_name] + + def unload_adapter(self, adapter_name: str): + """ + Deactivate adpater and remove it from CPU cache + + ``Mandatory`` Args: + :adapter_name (str): Adapter name to be unloaded + """ + + # step1: remove from active list if it's there + if adapter_name not in self.active_adapter_to_id.keys(): + logger.info(f"Adapter name {adapter_name} is not set active yet") + return False + + self.active_adapter_to_id.pop(adapter_name) + + # renumbering of active adapter id + for index, (key, value) in enumerate(self.active_adapter_to_id.items()): + self.active_adapter_to_id[key] = index + 1 + + logger.warning(f"Deleting {adapter_name} from active adapters.") + if self.onnx_path or self.qpc_path: + logger.warning("Please redo compile_and_export() to reflect the active adapters changes.") + self.onnx_path = None + self.qpc_path = None + + # step2: delete from cache + if adapter_name in self.adapter_weights.keys() and adapter_name in self.adapter_configs.keys(): + self.adapter_weights.pop(adapter_name) + self.adapter_configs.pop(adapter_name) + logger.warning(f"Unloading {adapter_name} from CPU cache.") + + return True + + def set_adapter(self, adapter_name: str): + raise NotImplementedError("Set adapter is not supported in finite_adapters mode") + + def _load_adapter_weights_to_model(self): + 
"Loads adapter weights to the model's multilora layer in a stacked format" + + num_hidden_layers = len(self.model.model.layers) + for i in range(num_hidden_layers): + for target_module in self.target_modules_for_all_adapters: + # stack all adapters weights + a_tensor_list = list(range(len(self.active_adapter_to_id) + 1)) + b_tensor_list = list(range(len(self.active_adapter_to_id) + 1)) + s_tensor_list = list(range(len(self.active_adapter_to_id) + 1)) + + for lora_name, lora_id in self.active_adapter_to_id.items(): + if target_module in ["q_proj", "k_proj", "v_proj", "o_proj"]: + a_tensor_list[lora_id] = torch.from_numpy( + self.adapter_weights[lora_name][ + f"base_model.model.model.layers.{i}.self_attn.{target_module}.lora_A.weight" + ] + ) + b_tensor_list[lora_id] = torch.from_numpy( + self.adapter_weights[lora_name][ + f"base_model.model.model.layers.{i}.self_attn.{target_module}.lora_B.weight" + ] + ) + else: + raise NotImplementedError("Target module not supported!!") + + s_tensor_list[lora_id] = torch.tensor( + self.adapter_configs[lora_name].lora_alpha / self.adapter_configs[lora_name].r, + dtype=torch.float16, + ) + + # dummy zero tensor for base model + a_tensor_list[0] = torch.zeros_like(a_tensor_list[1]) + b_tensor_list[0] = torch.zeros_like(b_tensor_list[1]) + s_tensor_list[0] = torch.zeros_like(s_tensor_list[1]) + + # stack weight tensors + stacked_lora_a = ( + torch.stack(a_tensor_list, dim=0).unsqueeze(1).transpose(2, 3) + ) # + stacked_lora_b = ( + torch.stack(b_tensor_list, dim=0).unsqueeze(1).transpose(2, 3) + ) # + stacked_lora_s = ( + torch.stack(s_tensor_list, dim=0).unsqueeze(1).unsqueeze(2).unsqueeze(3) + ) # + + # stored weight to corresponding ops + if target_module == "q_proj": + module = self.model.model.layers[i].self_attn.q_proj + elif target_module == "k_proj": + module = self.model.model.layers[i].self_attn.k_proj + elif target_module == "v_proj": + module = self.model.model.layers[i].self_attn.v_proj + elif target_module == "o_proj": + module = self.model.model.layers[i].self_attn.o_proj + else: + raise NotImplementedError("Target module not supported!!") + + module.lora_a_weights.copy_(stacked_lora_a) + module.lora_b_weights.copy_(stacked_lora_b) + module.lora_scalings.copy_(stacked_lora_s) + + def _init_adapter_model(self): + "Initialize the fixed lora model with multiple adapter weigths standby" + + # set lora rank + self.lora_rank = list(self.adapter_configs.values())[0].r + + # do the module replacement + _, transformed = LoraModelInputsTransform.apply(self.model) + + self.target_modules_for_all_adapters = list(self.adapter_configs.values())[0].target_modules + _, transformed = TargetModulesTransform.apply( + self.model, self.target_modules_for_all_adapters, self.lora_rank, len(self.active_adapter_to_id) + ) + + # load_weight to model + self._load_adapter_weights_to_model() + + def export(self, export_dir: Optional[str] = None) -> str: + """ + Exports the model to ``ONNX`` format using ``torch.onnx.export``. + We currently don't support exporting non-transformed models. Please refer to the ``convert_to_cloud_bertstyle`` function in the **Low-Level API** for a legacy function that supports this." + + ``Optional`` Args: + does not any arguments. + + Returns: + :str: Path of the generated ``ONNX`` graph. 
+ """ + + # initialize the adapter model + if len(self.active_adapter_to_id) == 0: + raise ValueError( + "Please use load_adapter() to add at least one adapter; otherwise, refer to QEFFAutoModelForCausalLM for base model usage" + ) + + self._init_adapter_model() + + bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + seq_len = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + fbs = constants.ONNX_EXPORT_EXAMPLE_FBS + kv_cache_shape = get_padding_shape_from_config( + self.model.config, fbs if self.continuous_batching else bs, seq_len + ) + example_inputs = { + "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), + "position_ids": torch.arange(seq_len, dtype=torch.int64).view(bs, seq_len), + "past_key_values": [[] for _ in range(self.num_layers)], + "lora_ids": torch.zeros(bs, dtype=torch.int64).view(bs, 1), + } + dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "position_ids": {0: "batch_size", 1: "seq_len"}, + "lora_ids": {0: "batch_size"}, + } + output_names = ["logits"] + for i in range(self.num_layers): + for kv in ["key", "value"]: + example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + dynamic_axes[f"past_{kv}.{i}"] = { + 0: "full_batch_size" if self.continuous_batching else "batch_size", + 2: "ctx_len", + } + output_names.append(f"past_{kv}.{i}_RetainedState") + + if self.continuous_batching: + example_inputs["batch_index"] = torch.arange(bs).view(bs, 1) + dynamic_axes["batch_index"] = {0: "batch_size"} + + return self._export( + example_inputs, + output_names, + dynamic_axes, + export_dir=export_dir, + ) + + def generate( + self, + tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer], + prompts: List[str], + device_id: List[int] = None, + prompt_to_adapter_mapping: List[str] = None, + runtime: str = "AI_100", + **kwargs, + ): + """ + This method generates output until ``eos`` or ``generation_len`` by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards. + This is a sequential execution based on the ``batch_size`` of the compiled model and the number of prompts passed. + If the number of prompts cannot be divided by the ``batch_size``, the last unfulfilled batch will be dropped. + + ``Mandatory`` Args: + :tokenizer (PreTrainedTokenizerFast or PreTrainedTokenizer): The tokenizer used in the inference + :prompts (List[str]): List of prompts to run the execution. + :device_id (List[int]): Ids of devices for running the qpc pass as [0] in case of normal model / [0, 1, 2, 3] in case of tensor slicing model + :prompt_to_adapter_mapping (List[str]): The sequence of the adapter names will be matched with sequence of prompts and corresponding adapters will be used for the prompts."base" for base model (no adapter). + ``optional`` Args: + :runtime (str, optional): Only ``AI_100`` runtime is supported as of now; ``ONNXRT`` and ``PyTorch`` coming soon. Defaults to "AI_100". 
+ + """ + if runtime != "AI_100": + raise ValueError("Only AI_100 runtime is supported right now via generate API") + if not isinstance(self.qpc_path, Path): + raise TypeError("Please run compile API first!") + generation_len = kwargs.pop("generation_len", None) + + if not prompt_to_adapter_mapping: + prompt_to_adapter_mapping = ["base" for _ in range(len(prompts))] + + if len(prompt_to_adapter_mapping) != len(prompts): + raise RuntimeError( + f"Number of prompts should match number of prompt_to_adapter_mapping, got len(prompts) = {len(prompts)}, len(prompt_to_adapter_mapping) = {len(prompt_to_adapter_mapping)}" + ) + + return QEfficient.cloud_ai_100_exec_kv( + tokenizer, + self.qpc_path, + prompt=prompts, + device_id=device_id, + generation_len=generation_len, + prompt_to_lora_id_mapping=[ + self.active_adapter_to_id[name] if name != "base" else 0 for name in prompt_to_adapter_mapping + ], + ) diff --git a/QEfficient/peft/lora/layers.py b/QEfficient/peft/lora/layers.py new file mode 100644 index 000000000..f197eb7ea --- /dev/null +++ b/QEfficient/peft/lora/layers.py @@ -0,0 +1,67 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import math +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from QEfficient.customop import CtxGatherFuncCB + + +class LinearMultiLoRA(nn.Linear): + def multilora_init(self, lora_rank, max_num_adapters): + if lora_rank < 1 or max_num_adapters < 1: + raise ValueError("lora_rank and max_num_adapters must be greater or equal to 1") + + self.max_num_adapters = max_num_adapters + self.lora_rank = lora_rank + + self.lora_a_weights = nn.Parameter( + self.weight.new_zeros(self.max_num_adapters + 1, 1, self.in_features, self.lora_rank) + ) + self.lora_a_weights.requires_grad = False + self.lora_b_weights = nn.Parameter( + self.weight.new_zeros(self.max_num_adapters + 1, 1, self.lora_rank, self.out_features) + ) + self.lora_b_weights.requires_grad = False + self.lora_scalings = torch.full((self.max_num_adapters + 1, 1, 1, 1), 1.0, dtype=torch.float) + + nn.init.kaiming_uniform_(self.lora_a_weights, a=math.sqrt(5)) + nn.init.zeros_(self.lora_b_weights) + + def forward(self, x: torch.Tensor, lora_ids: torch.Tensor): + result = F.linear(x, self.weight, bias=self.bias) + + # multilora implementation: lora_ids + other_indices_a = torch.arange(self.lora_a_weights.shape[2]).view(1, 1, -1) + selected_lora_a_weights = CtxGatherFuncCB.apply( + self.lora_a_weights, lora_ids, other_indices_a + ) # + other_indices_b = torch.arange(self.lora_b_weights.shape[2]).view(1, 1, -1) + selected_lora_b_weights = CtxGatherFuncCB.apply( + self.lora_b_weights, lora_ids, other_indices_b + ) # + other_indices_s = torch.arange(self.lora_scalings.shape[2]).view(1, 1, -1) + selected_lora_scalings = CtxGatherFuncCB.apply( + self.lora_scalings, lora_ids, other_indices_s + ) # + + selected_lora_a_weights = selected_lora_a_weights.squeeze(1) + selected_lora_b_weights = selected_lora_b_weights.squeeze(1) + selected_lora_scalings = selected_lora_scalings.squeeze(1) + + result = result + x @ selected_lora_a_weights @ selected_lora_b_weights * selected_lora_scalings + + return result + + +class LinearBase(nn.Linear): + def forward(self, x: torch.Tensor, **kwargs: Any): + return super().forward(x) diff --git 
a/QEfficient/peft/lora/lora_model.py b/QEfficient/peft/lora/lora_model.py new file mode 100644 index 000000000..456d3fdde --- /dev/null +++ b/QEfficient/peft/lora/lora_model.py @@ -0,0 +1,88 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +from typing import List, Optional, Tuple, Union + +import torch +from transformers.modeling_outputs import ( + CausalLMOutputWithPast, +) + +from QEfficient.transformers.models.llama.modeling_llama import QEffLlamaForCausalLM +from QEfficient.transformers.models.mistral.modeling_mistral import QEffMistralForCausalLM + + +class QEffLoraModelMistralForCausalLM(QEffMistralForCausalLM): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + lora_ids: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + kwargs["lora_ids"] = lora_ids + + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + +class QEffLoraModelLlamaForCausalLM(QEffLlamaForCausalLM): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + lora_ids: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + kwargs["lora_ids"] = lora_ids + + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) diff --git a/QEfficient/peft/lora/pytorch_transforms.py b/QEfficient/peft/lora/pytorch_transforms.py new file mode 100644 index 000000000..5e7463b97 --- /dev/null +++ b/QEfficient/peft/lora/pytorch_transforms.py @@ -0,0 +1,53 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +from typing import Dict, Optional, Tuple + +from torch import nn + +from QEfficient.base.pytorch_transforms import ModuleMappingTransform +from QEfficient.peft.lora.layers import LinearBase, LinearMultiLoRA +from QEfficient.peft.lora.lora_model import QEffLoraModelLlamaForCausalLM, QEffLoraModelMistralForCausalLM +from QEfficient.transformers.models.llama.modeling_llama import QEffLlamaForCausalLM +from QEfficient.transformers.models.mistral.modeling_mistral import QEffMistralForCausalLM + + +class LoraModelInputsTransform(ModuleMappingTransform): + _module_mapping = { + QEffMistralForCausalLM: QEffLoraModelMistralForCausalLM, + QEffLlamaForCausalLM: QEffLoraModelLlamaForCausalLM, + } + + +class TargetModulesTransform(ModuleMappingTransform): + _module_mapping = {nn.Linear: LinearMultiLoRA} + + _module_mapping_nontarget = {nn.Linear: LinearBase} + + # whole set of supported target modules for now (make sure **kwargs are passed in on modeling file) + all_modules = {"q_proj", "k_proj", "v_proj", "o_proj"} + + # a class method that deals with target module names + @classmethod + def apply( + cls, model: nn.Module, target_modules: Optional[Dict], lora_rank: int, max_num_adapters: int + ) -> Tuple[nn.Module, bool]: + transformed = False + nontarget_modules = {key for key in cls.all_modules if key not in target_modules} + + for name, module in model.named_modules(): + if repl_module := cls._module_mapping.get(type(module)): + if name.split(".")[-1] in target_modules: + module.__class__ = repl_module + if hasattr(module, "multilora_init"): + module.multilora_init(lora_rank, max_num_adapters) + transformed = True + elif name.split(".")[-1] in nontarget_modules: + module.__class__ = cls._module_mapping_nontarget.get(type(module)) + transformed = True + + return model, transformed diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index 4a1870380..679b4a2f9 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -168,6 +168,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -190,9 +191,9 @@ def forward( value_states = torch.cat(value_states, dim=-1) else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states, **kwargs) + key_states = self.k_proj(hidden_states, **kwargs) + value_states = self.v_proj(hidden_states, **kwargs) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -244,7 +245,7 @@ def forward( o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) else: - attn_output = self.o_proj(attn_output) + attn_output = self.o_proj(attn_output, **kwargs) if not output_attentions: attn_weights = None @@ -273,6 +274,7 @@ def forward( 
output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -318,6 +320,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) # Cast to INT32 to avoid issue while running in ONNXRT @@ -374,6 +377,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -403,6 +407,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + **kwargs, ) hidden_states = residual + hidden_states @@ -443,6 +448,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -515,6 +521,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + **kwargs, ) else: layer_outputs = decoder_layer( @@ -525,6 +532,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + **kwargs, ) hidden_states = layer_outputs[0] diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index ae913b42d..9fc71dc02 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -136,12 +136,13 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states, **kwargs) + key_states = self.k_proj(hidden_states, **kwargs) + value_states = self.v_proj(hidden_states, **kwargs) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -187,7 +188,7 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) + attn_output = self.o_proj(attn_output, **kwargs) if not output_attentions: attn_weights = None @@ -215,6 +216,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -294,6 +296,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + **kwargs, ) hidden_states = layer_outputs[0] @@ -413,6 +416,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: 
r""" Args: @@ -439,7 +443,6 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -459,6 +462,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) # Cast to int32 to avoid ONNXRT issue diff --git a/docs/source/hl_api.md b/docs/source/hl_api.md index 8ddc65ca7..798157be0 100644 --- a/docs/source/hl_api.md +++ b/docs/source/hl_api.md @@ -16,6 +16,12 @@ :members: ``` +## `QEffAutoLoraModelForCausalLM` +```{eval-rst} +.. autoclass:: QEfficient.lora.auto.QEffAutoLoraModelForCausalLM + :members: +``` + ## `export` ```{eval-rst} .. automodule:: QEfficient.exporter.export_hf_to_cloud_ai_100 diff --git a/examples/lora_models.py b/examples/lora_models.py new file mode 100644 index 000000000..b4a1cd921 --- /dev/null +++ b/examples/lora_models.py @@ -0,0 +1,133 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +## This example works on continuous batching with different lora adapters in the same batch ## + +from QEfficient import QEffAutoPeftModelForCausalLM +from QEfficient.utils import load_hf_tokenizer + +base_model_name = "mistralai/Mistral-7B-v0.1" +seq_len = 128 +ctx_len = 256 +full_batch_size = 4 +device_group = [0] + +## STEP 1 -- init base model +qeff_model = QEffAutoPeftModelForCausalLM.from_pretrained( + "predibase/gsm8k", "gsm8k", continuous_batching=True, finite_adapters=True +) + +# (alternative) non-cb compilation +# qeff_model = QEffAutoPeftModelForCausalLM.from_pretrained( +# "predibase/gsm8k", "gsm8k", continuous_batching=False, finite_adapters=True +# ) + +## STEP 2 -- load adapter adapter +qeff_model.load_adapter("predibase/tldr_content_gen", "tldr_content_gen") + +qeff_model.load_adapter("predibase/dbpedia", "dbpedia") + +# STEP 2 (optional) -- unload adapter +unload_status = qeff_model.unload_adapter("dbpedia") +print(f"Unloading dbpedia success: {unload_status}") + + +## STEP 3 -- export & compile qeff model +qpc_path = qeff_model.compile( + batch_size=1, + full_batch_size=full_batch_size, + prefill_seq_len=seq_len, + ctx_len=ctx_len, + num_devices=len(device_group), + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, +) + +# (alternative) non-cb compilation +# qpc_path = qeff_model.compile( +# batch_size=2, +# prefill_seq_len=seq_len, +# ctx_len=ctx_len, +# num_devices=len(device_group), +# num_cores=16, +# mxfp6_matmul=True, +# mxint8_kv_cache=True, +# ) + +## STEP 4 -- run inference on the generate function +prompts = [ + """Please answer the following question: James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many total meters does he run a week?\n\nAnswer:""", + """The following headline is the headline of a news report. 
Please write the content of the news passage based on only this headline.\n\nHeadline: Harvard shrank its insect-inspired microrobot to the size of a penny\n\nContent:""", + """Please answer the following question: Gene is sewing a quilt out of old souvenir t-shirts. He has one shirt from each vacation he has been on. Every shirt is its own quilt block. Each row is made of blocks from a different year of vacations. He goes on four vacations a year and has been vacationing since he was 23 years old. He is now 34. How many quilt blocks does he have in total?\n\nAnswer:""", + """The following headline is the headline of a news report. Please write the content of the news passage based on only this headline.\n\nHeadline: New neurons for life? Old people can still make fresh brain cells, study finds\n\nContent:""", + """Please answer the following question: Harry slept 9 hours last night. His friend James slept only 2/3 of what Harry slept. How many more hours did Harry sleep than James?\n\nAnswer:""", + """The following headline is the headline of a news report. Please write the content of the news passage based on only this headline.\n\nHeadline: Latest success from Google’s AI group: Controlling a fusion reactor\n\nContent:""", + """Please answer the following question: Gene is sewing a quilt out of old souvenir t-shirts. He has one shirt from each vacation he has been on. Every shirt is its own quilt block. Each row is made of blocks from a different year of vacations. He goes on four vacations a year and has been vacationing since he was 23 years old. He is now 34. How many quilt blocks does he have in total?\n\nAnswer:""", + """The following headline is the headline of a news report. Please write the content of the news passage based on only this headline.\n\nHeadline: TikTok Picks Streaming Service Audius to Power New ‘Sounds’ Library\n\nContent:""", +] + +qeff_model.generate( + tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=base_model_name), + prompts=prompts, + device_id=device_group, + prompt_to_adapter_mapping=[ + "gsm8k", + "tldr_content_gen", + "gsm8k", + "base", + "gsm8k", + "tldr_content_gen", + "gsm8k", + "tldr_content_gen", + ], +) + + +""" +expected response: + +<1> +He runs 3*3=<<3*3=9>>9 sprints a week +So he runs 9*60=<<9*60=540>>540 meters a week +#### 540 + +<2> +Researchers at Harvard have created a microrobot that is smaller than a penny. The robot is made of a flexible polymer that can be folded and unfolded to move. It is powered by a laser and can be controlled by a computer. The robot is able to move on its own, but it can also be controlled remotely. It can be used to deliver drugs or to perform other tasks. A 1-minute video that shows the robot in action is available in the article. + +<3> +He has been on 34-23=<<34-23=11>>11 vacations +He has 11*4=<<11*4=44>>44 blocks +#### 44 + +<4> +A new study has found that old people can still make fresh brain cells. The study was conducted by researchers at the University of California, San Francisco. They found that the brains of people in their 70s and 80s were still able brain cells + +Content: + +A new study has found that the brain of an old person can still make new neurons. The study was conducted by a team of researchers from the University of California, Los Angeles. The team studied the brains that were able to make new neurons. The team found that the brains of these people were able to make new neurons in the hippocampus, which is the part of the brain that is responsible for memory and learning. 
The team also found that the brains of these people were able to make new neurons in the cortex, which is the part of the brain that is responsible for thinking and reasoning. The team also found that the brains of these people were able to make new neurons in the cerebellum, which + +<5> +James slept 2/3 * 9 = <<2/3*9=6>>6 hours. +Harry slept 9 - 6 = <<9-6=3>>3 hours more than James. +#### 3 + +<6> +'s AI group has developed a system that can control a fusion reactor. The system uses a deep reinforcement learning +He has been alive for 11 years, so he has been alive for 11 x 365 = 4,055 days. +He has been alive for 4,055 days, so he has been alive for 4,055 x 24 = 97,300 hours. +He has been alive for 97,300 hours, so he has been alive for 97,300 x 60 = 5,838,000 minutes. +He has been alive for 5,838,000 minutes, so he has been alive for 5,83 kennis + +<7> +He has been on 34-23=<<34-23=11>>11 vacations. +He has 11*4=<<11*4=44>>44 blocks. +#### 44 + +<8> +TikTok has partnered with Audius to power its new Sounds library. The Sounds library will allow users to discover and share sounds from a wide range of creators. Audius is a music streaming platform that allows artists to upload their music and share it with fans. It has a community of over 1.5 million users. TikTok has been working on the Sounds library for over a year. The library will be available in the US, Canada, and Australia. +""" diff --git a/tests/peft/lora/test_lora_model.py b/tests/peft/lora/test_lora_model.py new file mode 100644 index 000000000..a91555b3a --- /dev/null +++ b/tests/peft/lora/test_lora_model.py @@ -0,0 +1,234 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +from pathlib import Path +from time import perf_counter + +import numpy as np +import pytest +from peft import LoraConfig +from transformers import AutoConfig, AutoModelForCausalLM + +from QEfficient import QEffAutoPeftModelForCausalLM +from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM +from QEfficient.utils import load_hf_tokenizer + +configs = [ + pytest.param( + AutoConfig.for_model( + "llama", num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2, hidden_size=128 + ), + LoraConfig(target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM", lora_alpha=8), + id="llama-2l-4h-2kvh-128d-qv", + ), + pytest.param( + AutoConfig.for_model( + "mistral", num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2, hidden_size=128 + ), + LoraConfig(target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM", lora_alpha=6), + id="mistral-2l-4h-128d-qv", + ), +] + +model_samples = [ + pytest.param("mistralai/Mistral-7B-v0.1", "predibase/gsm8k", "predibase/dbpedia"), + pytest.param( + "meta-llama/Meta-Llama-3-8B", + "hallisky/lora-type-narrative-llama-3-8b", + "hallisky/lora-grade-elementary-llama-3-8b", + ), +] + + +def create_lora_base_model(base_config): + base_model = AutoModelForCausalLM.from_config(base_config, attn_implementation="eager") + lora_base_model = QEffAutoLoraModelForCausalLM(base_model) + + return lora_base_model + + +# test model initialization using __init__ approach +@pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples) +def test_auto_lora_model_for_causal_lm_init(base_model_name, adapter_id_0, adapter_id_1): + model_hf = 
AutoModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1) + qeff_model = QEffAutoLoraModelForCausalLM(model_hf) + + assert len(qeff_model.adapter_weights) == 0 + assert len(qeff_model.adapter_configs) == 0 + assert len(qeff_model.active_adapter_to_id) == 0 + + +# test model initialization using from_pretrained approach +@pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples) +def test_auto_lora_model_for_causal_lm_from_pretrained(base_model_name, adapter_id_0, adapter_id_1): + qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=base_model_name, num_hidden_layers=1 + ) + + assert len(qeff_model.adapter_weights) == 0 + assert len(qeff_model.adapter_configs) == 0 + assert len(qeff_model.active_adapter_to_id) == 0 + + +# test peft model initialization using from_pretrained approach +@pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples) +def test_auto_peft_model_for_causal_lm_from_pretrained(base_model_name, adapter_id_0, adapter_id_1): + qeff_model = QEffAutoPeftModelForCausalLM.from_pretrained( + adapter_id_0, "id_0", finite_adapters=True, num_hidden_layers=1 + ) + qeff_model_tmp = QEffAutoPeftModelForCausalLM.from_pretrained( + adapter_id_0, adapter_name="id_0", finite_adapters=True, num_hidden_layers=1 + ) + + assert qeff_model.active_adapter_to_id == qeff_model_tmp.active_adapter_to_id + del qeff_model_tmp + assert isinstance(qeff_model, QEffAutoLoraModelForCausalLM) + assert len(qeff_model.adapter_weights) == 1 + assert len(qeff_model.adapter_configs) == 1 + assert len(qeff_model.active_adapter_to_id) == 1 + + # test pass without adapter name + with pytest.raises(TypeError): + QEffAutoLoraModelForCausalLM.from_pretrained(adapter_id_0, finite_adapters=True, num_hidden_layers=1) + + # test pass with adapter name as integer + with pytest.raises(TypeError): + QEffAutoLoraModelForCausalLM.from_pretrained(adapter_id_0, 0, finite_adapters=True, num_hidden_layers=1) + + +# test the init assertion for models that are not supported +@pytest.mark.parametrize("base_model_name", ["distilbert/distilgpt2"]) +def test_auto_lora_model_for_causal_lm_init_from_unsupported_model(base_model_name): + model_hf = AutoModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1) + with pytest.raises(NotImplementedError): + QEffAutoLoraModelForCausalLM(model_hf) + + with pytest.raises(NotImplementedError): + QEffAutoLoraModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1) + + +# test model hash +def test_auto_lora_model_for_causal_lm_hash(): + base_config_0, adapter_config_0 = configs[0].values + base_config_1, adapter_config_1 = configs[1].values + + qeff_model_0 = create_lora_base_model(base_config_0) + qeff_model_0.load_adapter( + "dummy_id", "adapter_0", adapter_config=adapter_config_0, adapter_weight={"weights": np.ones((3, 3))} + ) + qeff_model_0.load_adapter( + "dummy_id", "adapter_1", adapter_config=adapter_config_1, adapter_weight={"weights": np.ones((3, 3))} + ) + model_hash_0_0 = qeff_model_0.model_hash + + qeff_model_1 = create_lora_base_model(base_config_1) + qeff_model_1.load_adapter( + "dummy_id", "adapter_0", adapter_config=adapter_config_0, adapter_weight={"weights": np.ones((3, 3))} + ) + qeff_model_1.load_adapter( + "dummy_id", "adapter_1", adapter_config=adapter_config_1, adapter_weight={"weights": np.ones((3, 3))} + ) + model_hash_1_0 = qeff_model_1.model_hash + + qeff_model_0_1 = create_lora_base_model(base_config_0) + qeff_model_0_1.load_adapter( + 
"dummy_id", "adapter_0", adapter_config=adapter_config_0, adapter_weight={"weights": np.ones((3, 3))} + ) + qeff_model_0_1.load_adapter( + "dummy_id", "adapter_1", adapter_config=adapter_config_1, adapter_weight={"weights": np.ones((3, 3))} + ) + model_hash_0_1_0 = qeff_model_0_1.model_hash + + # check if same model, same adapter config, same adapter weight, result in same hash + assert model_hash_0_1_0 == model_hash_0_0 + + # check if same model, same adapter config, but different weight, result in different hash + qeff_model_0_1.unload_adapter("adapter_1") + qeff_model_0_1.unload_adapter("adapter_0") + qeff_model_0_1.load_adapter( + "dummy_id", "adapter_0", adapter_config=adapter_config_0, adapter_weight={"weights": np.random.randn(3, 3)} + ) + qeff_model_0_1.load_adapter( + "dummy_id", "adapter_1", adapter_config=adapter_config_1, adapter_weight={"weights": np.random.randn(3, 3)} + ) + model_hash_0_1_1 = qeff_model_0_1.model_hash + assert model_hash_0_1_1 != model_hash_0_0 + + # check base model configs difference result in different hash + assert model_hash_0_0 != model_hash_1_0 + + # check different adapter orders, result in different hash + qeff_model_1.unload_adapter("adapter_0") + qeff_model_1.unload_adapter("adapter_1") + qeff_model_1.load_adapter( + "dummy_id", "adapter_1", adapter_config=adapter_config_1, adapter_weight={"weights": np.ones((3, 3))} + ) + qeff_model_1.load_adapter( + "dummy_id", "adapter_0", adapter_config=adapter_config_0, adapter_weight={"weights": np.ones((3, 3))} + ) + model_hash_1_1 = qeff_model_1.model_hash + assert model_hash_1_1 != model_hash_1_0 + + # check if same adapter name, but different config, result in different hash + qeff_model_0.unload_adapter("adapter_1") + qeff_model_0.load_adapter( + "dummy_id", "adapter_1", adapter_config=adapter_config_0, adapter_weight={"weights": np.ones((3, 3))} + ) + model_hash_0_1 = qeff_model_0.model_hash + assert model_hash_0_1 != model_hash_0_0 + + +# test download_adapter(), load_adapter() and unload_adapter() +@pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples[1:]) +def test_auto_lora_model_for_causal_lm_load_unload_adapter(base_model_name, adapter_id_0, adapter_id_1): + qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1) + + qeff_model.download_adapter(adapter_id_0, "adapter_0") + qeff_model.download_adapter(adapter_id_1, "adapter_1") + + qeff_model.load_adapter(adapter_id_0, "adapter_0") + + assert not qeff_model.unload_adapter("adapter_1") # not active adapter + assert qeff_model.unload_adapter("adapter_0") # valid unload + + +# test the export, export caching, compile, generate workflow +@pytest.mark.on_qaic +@pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples[:1]) +def test_auto_lora_model_for_causal_lm_export_compile_generate(base_model_name, adapter_id_0, adapter_id_1, tmp_path): + qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1) + + qeff_model.load_adapter(adapter_id_0, "adapter_0") + qeff_model.load_adapter(adapter_id_1, "adapter_1") + + # export + start = perf_counter() + qeff_model.export(export_dir=tmp_path) + end = perf_counter() + export_time_0 = end - start + model_path = tmp_path.with_name(tmp_path.name + "-" + qeff_model.model_hash) + assert model_path.is_dir() + assert Path(qeff_model.onnx_path).is_file() + + # test export caching + start = perf_counter() + qeff_model.export(export_dir=tmp_path) + end = perf_counter() + export_time_1 = end - 
start + assert export_time_1 < export_time_0 + + # test compile + qeff_model.compile(prefill_seq_len=32, ctx_len=64) + assert Path(qeff_model.qpc_path).is_dir() + + # test generate + prompts = ["hello!", "hi", "hello, my name is", "hey"] + qeff_model.generate( + tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=base_model_name), + prompts=prompts, + device_id=[0], + prompt_to_adapter_mapping=["adapter_0", "adapter_1", "adapter_0", "base"], + )
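For orientation, a compact end-to-end sketch of the workflow the tests above exercise (load adapters, compile, run mixed-adapter generation). The model card, adapter ids, and compile options are illustrative, and execution requires a Cloud AI 100 device:

```python
from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM
from QEfficient.utils import load_hf_tokenizer

base_model = "mistralai/Mistral-7B-v0.1"
qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(base_model)
qeff_model.load_adapter("predibase/gsm8k", "gsm8k")  # assigned lora id 1
qeff_model.load_adapter("predibase/tldr_content_gen", "tldr_content_gen")  # assigned lora id 2

qeff_model.compile(prefill_seq_len=32, ctx_len=64, num_cores=16)  # exports to ONNX first if needed

qeff_model.generate(
    tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=base_model),
    prompts=["What is 12 * 7?", "Write a headline about robotics.", "Hello there"],
    device_id=[0],
    prompt_to_adapter_mapping=["gsm8k", "tldr_content_gen", "base"],
)
```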