Workaround Low-Mem-Mode Patch for GPTQ-LoRA #26

Merged · 7 commits · May 29, 2024
@@ -16,14 +16,77 @@
# https://spdx.dev/learn/handling-license-info/

# Standard
from typing import Callable, List
from typing import Any, Callable, List
import importlib

# Third Party
from peft import LoraConfig
from peft.tuners.lora.gptq import QuantLinear as LoraLinearGPTQ
import torch


# This function may be moved after merging
# https://github.com/foundation-model-stack/fms-acceleration/pull/25
def _patch_target_module(
    to_patch: str,
    replace_with: Any,
    target_module: str = None,
):
    to_patch = to_patch.split(".")
    assert len(to_patch) > 1, "must have an object to patch"

    to_patch, obj_name_to_patch = to_patch[:-1], to_patch[-1]
    to_patch = ".".join(to_patch)
    source = importlib.import_module(to_patch)
    original_obj = getattr(source, obj_name_to_patch)
    setattr(source, obj_name_to_patch, replace_with)

    if target_module is not None:
        # reload the target module so it picks up the patched object
        target_module = importlib.import_module(target_module)
        importlib.reload(target_module)

        # restore the original object now that the target module has been reloaded
        setattr(source, obj_name_to_patch, original_obj)


def make_sure_no_tensor_in_meta_device(
    model,
    use_triton: bool,
    desc_act: bool,
    group_size: int,
    bits: int,
    disable_exllama: bool,
    disable_exllamav2: bool,
    use_marlin: bool = False,
    use_tritonv2: bool = False,
):
    # Third Party
    # guarded import
    from auto_gptq.utils.import_utils import (  # pylint: disable=import-outside-toplevel,import-error
        dynamically_import_QuantLinear,
    )

    QuantLinear = dynamically_import_QuantLinear(
        use_triton,
        desc_act,
        group_size,
        bits=bits,
        disable_exllama=disable_exllama,
        disable_exllamav2=disable_exllamav2,
        use_marlin=use_marlin,
        use_tritonv2=use_tritonv2,
    )
    for _, m in model.named_modules():
        bias = getattr(m, "bias", None)
        # check for None explicitly; truth-testing a multi-element tensor raises a RuntimeError
        if bias is not None:
            if isinstance(m, QuantLinear) and bias.device == torch.device("meta"):
                m.register_buffer(
                    "bias",
                    torch.zeros((m.outfeatures), dtype=torch.float16, device="cpu"),
                )
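
For orientation, here is a short sketch of the patch-and-reload behaviour of `_patch_target_module`, using hypothetical modules `mypkg.utils` (which defines `helper`) and `mypkg.consumer` (which does `from mypkg.utils import helper` at import time); in this PR the real targets are AutoGPTQ's `_utils` and `_base` modules, as shown further down.

```
from .autogptq_utils import _patch_target_module  # relative import, as used in this PR

def patched_helper(*args, **kwargs):
    """Hypothetical replacement object used only for this sketch."""

_patch_target_module(
    to_patch="mypkg.utils.helper",    # object to swap out
    replace_with=patched_helper,      # object to swap in
    target_module="mypkg.consumer",   # module to reload so it re-imports `helper`
)
# After the call, `mypkg.consumer.helper` is `patched_helper` (picked up during the
# reload), while `mypkg.utils.helper` has been restored to the original object, so
# the patch is effectively scoped to the reloaded consumer module.
```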


def replace_module_peft(self, parent_module, child_name, new_module, old_module):

    # replace the lora linear
@@ -27,6 +27,7 @@
from peft import LoraConfig, prepare_model_for_kbit_training
from peft.tuners.lora.model import LoraModel
from transformers import AutoModelForCausalLM, TrainingArguments
from transformers.modeling_utils import is_fsdp_enabled
import torch
import torch.distributed

@@ -48,14 +49,15 @@ def __init__(self, configurations: Dict[str, Dict]):
)

def model_loader(self, model_name: str, **kwargs):

    # guarded imports
    # Third Party
    from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig  # pylint: disable=import-outside-toplevel,import-error
    from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear  # pylint: disable=import-outside-toplevel,import-error

    # Local
    from .autogptq_utils import patch_forward_to_view_attributes_before_call  # pylint: disable=import-outside-toplevel
    from .autogptq_utils import (  # pylint: disable=import-outside-toplevel
        patch_forward_to_view_attributes_before_call,
    )

    # Currently we allow only a quantized checkpoint to be loaded; we do not
    # implement the quantization process here.
@@ -84,20 +86,6 @@ def model_loader(self, model_name: str, **kwargs):
    low_cpu_mem_usage = kwargs.get("low_cpu_mem_usage")
    attn_implementation = kwargs.get("attn_implementation")

    if low_cpu_mem_usage:
        # Note that low_cpu_mem_usage is typically set to
        # transformers.modeling_utils.is_fsdp_enabled.
        # e.g.,
        # https://github.com/huggingface/transformers/blob/a98c41798cf6ed99e1ff17e3792d6e06a2ff2ff3/src/transformers/modeling_utils.py#L2989-L2990
        # but not doing that now as AutoGPTQ will call make_sure_no_tensor_in_meta_device
        # https://github.com/AutoGPTQ/AutoGPTQ/blob/ea829c7bbe83561c2b1de26795b6592992373ef7/auto_gptq/modeling/_base.py#L982C17-L982C51
        # which does not properly check if a QuantLayer has a bias set or not,
        # https://github.com/AutoGPTQ/AutoGPTQ/blob/ea829c7bbe83561c2b1de26795b6592992373ef7/auto_gptq/modeling/_utils.py#L514
        raise ValueError(
            "low_cpu_mem_usage set to True. This may raise error if model has no bias, "
            "due to AutoGPTQ bug. Not supporting at the moment."
        )

    # there are some kwargs that won't be passed to AutoModel, so we need
    # to patch them in
    _old_from_config = AutoModelForCausalLM.from_config
@@ -107,14 +95,40 @@ def model_loader(self, model_name: str, **kwargs):
    )
    AutoModelForCausalLM.from_config = _from_config  # patch

    # NOTE: need to set the device map as below as we want to
    # use AutoGPTQ for training.
    # this is a HF method that checks if the low_cpu_mem mode is enabled
    # via HF accelerate
    if is_fsdp_enabled():
        # Local
        from .autogptq_utils import (  # pylint: disable=import-outside-toplevel
            _patch_target_module,
            make_sure_no_tensor_in_meta_device,
        )

        # We patch `make_sure_no_tensor_in_meta_device`
        # from autogptq to avoid errors on models without bias
        _patch_target_module(
            to_patch="auto_gptq.modeling._utils.make_sure_no_tensor_in_meta_device",
            replace_with=make_sure_no_tensor_in_meta_device,
            target_module="auto_gptq.modeling._base",
        )
        low_cpu_mem_usage = True

    # NOTE: need to set the device map as below as we want to use AutoGPTQ for training.
    # device_map is for inference only
    # ref: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference
    # Thus we set it as below to effectively disable it.
    device_map = (
        {"": torch.cuda.current_device()} if torch.cuda.is_available() else None
    )
    # https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference
    # For low_cpu_mem_usage = True, we have to set the device map to load checkpoints to "cpu"
    # to avoid GPU consumption before training.
    # This approach diverts consumption to CPU memory; a better approach would be
    # to load the checkpoints to the meta device.
    # QLoRA is currently implemented by the former approach and will encounter the same issue.
    # see https://github.com/huggingface/transformers/pull/25107#issuecomment-2134833262
    device_map = {
        "": (
            (torch.cuda.current_device() if not low_cpu_mem_usage else "cpu")
            if torch.cuda.is_available()
            else None
        )
    }

    # currently we enable only triton_v2, because the triton kernels are the only ones
    # that have a backward pass
@@ -204,9 +218,11 @@ def augmentation(
    # Third Party
    from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear  # pylint: disable=import-outside-toplevel,import-error
    from auto_gptq.utils.peft_utils import GPTQLoraModel, get_gptq_peft_model  # pylint: disable=import-outside-toplevel,import-error

    # Local
    from .autogptq_utils import create_new_module_peft, replace_module_peft  # pylint: disable=import-outside-toplevel
    from .autogptq_utils import (  # pylint: disable=import-outside-toplevel
        create_new_module_peft,
        replace_module_peft,
    )

    (peft_config,) = modifiable_args  # unpack modifiable args
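To summarize the loading behaviour that the two `device_map` assignments above encode, here is a small standalone sketch; the helper name `_select_device_map` is made up for illustration and follows the newer branch of the diff.

```
import torch

def _select_device_map(low_cpu_mem_usage: bool):
    # No CUDA at all: leave the single entry unset and let the loader decide.
    if not torch.cuda.is_available():
        return {"": None}
    # Low-memory (FSDP) mode: keep checkpoint weights on CPU so the GPU is not
    # consumed before the trainer shards the model.
    if low_cpu_mem_usage:
        return {"": "cpu"}
    # Otherwise place everything on the current GPU, effectively disabling
    # inference-style device mapping.
    return {"": torch.cuda.current_device()}
```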
scripts/benchmarks/README.md: 2 additions & 1 deletion
@@ -164,6 +164,7 @@ We currently compute the memory values in the report by taking the largest of su
For allocated memory value
```
max([
stage0_mem,
stage0_mem + stage1_allocated_delta,
stage0_mem + stage1_allocated_delta + stage2_allocated_delta,
...
])
```

@@ -173,13 +174,13 @@ max([
For peak memory value
```
max([
stage0_mem,
stage0_mem + stage1_allocated_delta + stage1_peaked_delta,
stage0_mem + stage1_allocated_delta + stage2_allocated_delta + stage2_peaked_delta,
...
])
```
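
As a concrete illustration of these two formulas with made-up numbers, suppose `stage0_mem = 10`, the stage-1 deltas are `allocated = 4` and `peaked = 6`, and the stage-2 deltas are `allocated = 2` and `peaked = 9`:

```
allocated = max([10, 10 + 4, 10 + 4 + 2])          # = 16
peak      = max([10, 10 + 4 + 6, 10 + 4 + 2 + 9])  # = 25
```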

Notice that we do not include `stage0_mem` alone when computing the max value. This avoids misleading comparisons between GPTQ-LoRA and other variants: GPTQ-LoRA + FSDP currently does not support low-memory mode, as mentioned [here](https://github.com/foundation-model-stack/fms-acceleration/issues/18), so its `stage0_mem` is larger than expected because the model is fully loaded before the trainer is initialized and only sharded afterwards inside `trainer.prepare`. Comparing it against variants that are loaded in low-memory mode, and therefore report a smaller `stage0_mem`, would be misleading. Once low-memory mode is supported for GPTQ-LoRA, we will include `stage0_mem` back in the max computation.

We compare memory values reported by Nvidia-SMI and by Torch in [Memory Benchmarking](https://github.com/foundation-model-stack/fms-acceleration/pull/14).

scripts/benchmarks/benchmark.py: 2 additions & 2 deletions
@@ -112,8 +112,8 @@ def extract_gpu_memory_metrics(output_metrics) -> Tuple[float]:
        return 0, 0

    trainer_stage_order = [
        (HF_TRAINER_LOG_GPU_STAGE_BEFORE_INIT, False),
        (HF_TRAINER_LOG_GPU_STAGE_INIT, False),
        (HF_TRAINER_LOG_GPU_STAGE_BEFORE_INIT, True),
        (HF_TRAINER_LOG_GPU_STAGE_INIT, True),
        (HF_TRAINER_LOG_GPU_STAGE_TRAIN, True),
    ]
    alloc_running_sum = 0
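The loop that consumes `trainer_stage_order` is not shown in this diff; the sketch below illustrates one plausible shape of that computation, consistent with the README formulas above. The metric key names and list variables are made up for this sketch; only `trainer_stage_order`, `alloc_running_sum`, and `output_metrics` come from the code shown.

```
alloc_running_sum = 0
allocated_candidates = []
peak_candidates = []
for stage, include in trainer_stage_order:
    # per-stage deltas taken from the HFTrainer GPU-memory log (key names assumed)
    alloc_delta = output_metrics.get(f"{stage}_mem_allocated_delta", 0)
    peak_delta = output_metrics.get(f"{stage}_mem_peaked_delta", 0)
    alloc_running_sum += alloc_delta
    if include:
        # flipping the first two flags to True is what adds the earlier-stage
        # terms (e.g. stage0_mem) back into the max computations
        allocated_candidates.append(alloc_running_sum)
        peak_candidates.append(alloc_running_sum + peak_delta)

mem_allocated = max(allocated_candidates)
mem_peak = max(peak_candidates)
```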