workaround low-mem patch
achew010 committed May 29, 2024
1 parent 79bc89b commit e4e32b6
Showing 2 changed files with 48 additions and 17 deletions.
@@ -22,6 +22,36 @@
import torch


def make_sure_no_tensor_in_meta_device(
    model,
    use_triton: bool,
    desc_act: bool,
    group_size: int,
    bits: int,
    disable_exllama: bool,
    disable_exllamav2: bool,
    use_marlin: bool = False,
    use_tritonv2: bool = False,
):
    from auto_gptq.utils.import_utils import dynamically_import_QuantLinear  # pylint: disable=import-outside-toplevel,import-error
    QuantLinear = dynamically_import_QuantLinear(
        use_triton,
        desc_act,
        group_size,
        bits=bits,
        disable_exllama=disable_exllama,
        disable_exllamav2=disable_exllamav2,
        use_marlin=use_marlin,
        use_tritonv2=use_tritonv2,
    )
    for n, m in model.named_modules():
        bias = getattr(m, "bias", None)
        # explicit None check: truth-testing a tensor raises for multi-element
        # (and meta) tensors, and some QuantLinear layers carry no bias at all
        if bias is not None:
            if isinstance(m, QuantLinear) and bias.device == torch.device("meta"):
                m.register_buffer(
                    "bias",
                    torch.zeros(m.outfeatures, dtype=torch.float16, device="cpu"),
                )

def replace_module_peft(self, parent_module, child_name, new_module, old_module):

    # replace the lora linear
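
As an aside, here is a minimal, self-contained sketch (separate from the commit) of the behaviour the patched helper above is going for: modules without a bias are skipped, while a quantized layer whose bias was left on the "meta" device by low-memory loading gets a real, zero-initialized CPU bias. FakeQuantLinear and fill_meta_bias are hypothetical stand-ins for auto_gptq's QuantLinear and the helper; only torch is assumed.

import torch
import torch.nn as nn

class FakeQuantLinear(nn.Module):
    """Hypothetical stand-in for auto_gptq's QuantLinear."""
    def __init__(self, outfeatures: int, with_bias: bool):
        super().__init__()
        self.outfeatures = outfeatures
        if with_bias:
            # simulate a bias left on the meta device by low-memory loading
            self.register_buffer("bias", torch.empty(outfeatures, device="meta"))
        else:
            self.bias = None

def fill_meta_bias(model: nn.Module, quant_cls=FakeQuantLinear):
    # mirrors the patched helper: skip bias-less modules, materialize meta biases on CPU
    for _, m in model.named_modules():
        bias = getattr(m, "bias", None)
        if bias is None:
            continue
        if isinstance(m, quant_cls) and bias.device == torch.device("meta"):
            m.register_buffer(
                "bias", torch.zeros(m.outfeatures, dtype=torch.float16, device="cpu")
            )

model = nn.Sequential(FakeQuantLinear(4, True), FakeQuantLinear(4, False))
fill_meta_bias(model)
print(model[0].bias.device)  # cpu: the meta bias was materialized
print(model[1].bias)         # None: the bias-less layer is left untouched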
@@ -27,9 +27,10 @@
from peft.tuners.lora.model import LoraModel
import torch.distributed
from transformers import AutoModelForCausalLM, TrainingArguments
from transformers.modeling_utils import is_fsdp_enabled
import torch
import os

import importlib

class AutoGPTQAccelerationPlugin(AccelerationPlugin):

@@ -48,7 +49,6 @@ def __init__(self, configurations: Dict[str, Dict]):
        )

    def model_loader(self, model_name: str, **kwargs):

        # guarded imports
        # Third Party
        from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
@@ -80,18 +80,6 @@ def model_loader(self, model_name: str, **kwargs):
        low_cpu_mem_usage = kwargs.get("low_cpu_mem_usage")
        attn_implementation = kwargs.get("attn_implementation")

        if low_cpu_mem_usage:
            # Note that low_cpu_mem_usage is typically set to the result of
            # transformers.modeling_utils.is_fsdp_enabled,
            # e.g., https://github.com/huggingface/transformers/blob/a98c41798cf6ed99e1ff17e3792d6e06a2ff2ff3/src/transformers/modeling_utils.py#L2989-L2990
            # but we are not doing that now, as AutoGPTQ will call make_sure_no_tensor_in_meta_device
            # https://github.com/AutoGPTQ/AutoGPTQ/blob/ea829c7bbe83561c2b1de26795b6592992373ef7/auto_gptq/modeling/_base.py#L982C17-L982C51
            # which does not properly check whether a QuantLinear layer has a bias set,
            # https://github.com/AutoGPTQ/AutoGPTQ/blob/ea829c7bbe83561c2b1de26795b6592992373ef7/auto_gptq/modeling/_utils.py#L514
            raise ValueError(
                "low_cpu_mem_usage set to True. This may raise an error if the model has no bias, "
                "due to an AutoGPTQ bug. Not supported at the moment."
            )

        # there are some kwargs that won't be passed to AutoModel, so we need
        # to patch them in
        _old_from_config = AutoModelForCausalLM.from_config
@@ -103,12 +91,25 @@
        )
        AutoModelForCausalLM.from_config = _from_config  # patch

        if is_fsdp_enabled():
            # patch auto_gptq's helper with the version from .autogptq_utils, which
            # tolerates QuantLinear layers that have no bias
            from .autogptq_utils import make_sure_no_tensor_in_meta_device
            source = importlib.import_module("auto_gptq.modeling._utils")
            original_obj = getattr(source, "make_sure_no_tensor_in_meta_device")  # handle to the original
            setattr(source, "make_sure_no_tensor_in_meta_device", make_sure_no_tensor_in_meta_device)
            # reload the consumer module so it rebinds to the patched object
            target_module = importlib.import_module("auto_gptq.modeling._base")
            importlib.reload(target_module)
            low_cpu_mem_usage = True

        # NOTE: the device map is set as below because we want to use AutoGPTQ for training.
        # device_map is an inference-only feature, see
        # https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference
        # Thus we set it as below to effectively disable it.
        device_map = (
            {"": torch.cuda.current_device()} if torch.cuda.is_available() else None
        )
        device_map = {
            "": (
                # under low_cpu_mem_usage (i.e., FSDP) keep the initial load on CPU
                torch.cuda.current_device() if not low_cpu_mem_usage
                else "cpu"
            ) if torch.cuda.is_available() else None
        }

        # currently we only enable triton_v2, because those are the only triton kernels
        # that implement a backward pass

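As context for the patch above (separate from the commit itself): when a module binds a function via "from module import fn", patching the attribute on the source module alone is not enough, because the consumer keeps its own reference to the original; the consumer has to be reloaded so the name is rebound. The standard-library-only sketch below demonstrates the pattern; the module names utils_mod and base_mod are made up for illustration.

import importlib
import pathlib
import sys
import tempfile

# write two tiny modules: base_mod does `from utils_mod import fn`
workdir = tempfile.mkdtemp()
pathlib.Path(workdir, "utils_mod.py").write_text("def fn():\n    return 'original'\n")
pathlib.Path(workdir, "base_mod.py").write_text(
    "from utils_mod import fn\n\ndef call():\n    return fn()\n"
)
sys.path.insert(0, workdir)

import base_mod
import utils_mod

# patch the source module, analogous to setattr(source, "make_sure_no_tensor_in_meta_device", ...)
setattr(utils_mod, "fn", lambda: "patched")
print(base_mod.call())  # 'original': base_mod still holds its old binding

# reload the consumer, analogous to reloading auto_gptq.modeling._base
base_mod = importlib.reload(base_mod)
print(base_mod.call())  # 'patched': the from-import re-runs and picks up the patch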