Add kernel torch.compile hook (#1265)
* add compile() hook for every kernel

* ruff compat
Qubitium authored Feb 12, 2025
1 parent 76b169e commit ff72d31
Showing 101 changed files with 124 additions and 348 deletions.
6 changes: 2 additions & 4 deletions examples/benchmark/generation_speed.py
@@ -23,12 +23,10 @@

import torch
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer, GenerationConfig
from transformers.generation.logits_process import LogitsProcessor

from gptqmodel import BACKEND, GPTQModel, QuantizeConfig
from gptqmodel.utils.progress import ProgressBar

from transformers import AutoTokenizer, GenerationConfig
from transformers.generation.logits_process import LogitsProcessor

logger = logging.getLogger(__name__)

2 changes: 0 additions & 2 deletions examples/benchmark/ipex.py
@@ -20,7 +20,6 @@
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


try:
    from optimum.intel.utils.modeling_utils import bind_cores_for_best_perf
    bind_cores_for_best_perf()
@@ -30,7 +29,6 @@

import argparse


parser = argparse.ArgumentParser(description="Benchmark IPEX vs HF on a pre-trained model.")
parser.add_argument("--model", type=str, required=True, help="Path or name of the pre-trained model.")
parser.add_argument("--cores", type=int, default=8, help="Number of CPU cores to use.")
4 changes: 1 addition & 3 deletions examples/benchmark/perplexity.py
@@ -17,10 +17,8 @@
import argparse
import os

from transformers import AutoTokenizer

from gptqmodel.utils import Perplexity

from transformers import AutoTokenizer

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

4 changes: 1 addition & 3 deletions examples/evaluation/run_language_modeling_task.py
@@ -18,12 +18,10 @@

import datasets
import torch
from transformers import AutoTokenizer

from gptqmodel import BACKEND, GPTQModel, QuantizeConfig
from gptqmodel.eval_tasks import LanguageModelingTask
from gptqmodel.utils.torch import torch_empty_cache

from transformers import AutoTokenizer

DATASET = "tatsu-lab/alpaca"
WITH_INPUT_TEMPLATE = "Instruction:\n{instruction}\n\nInput:\n{input}\n\nOutput:\n"
4 changes: 1 addition & 3 deletions examples/evaluation/run_sequence_classification_task.py
@@ -19,12 +19,10 @@

import datasets
import torch
from transformers import AutoTokenizer

from gptqmodel import BACKEND, GPTQModel, QuantizeConfig
from gptqmodel.eval_tasks import SequenceClassificationTask
from gptqmodel.utils.torch import torch_empty_cache

from transformers import AutoTokenizer

DATASET = "cardiffnlp/tweet_sentiment_multilingual"
TEMPLATE = "Question:What's the sentiment of the given text? Choices are {labels}.\nText: {text}\nAnswer:"
4 changes: 1 addition & 3 deletions examples/evaluation/run_text_summarization_task.py
@@ -19,12 +19,10 @@

import datasets
import torch
from transformers import AutoTokenizer, GenerationConfig

from gptqmodel import BACKEND, GPTQModel, QuantizeConfig
from gptqmodel.eval_tasks import TextSummarizationTask
from gptqmodel.utils.torch import torch_empty_cache

from transformers import AutoTokenizer, GenerationConfig

os.system("pip install py7zr")

1 change: 0 additions & 1 deletion examples/inference/run_transformers.py
@@ -16,7 +16,6 @@

from transformers import AutoModelForCausalLM, AutoTokenizer


tokenizer = AutoTokenizer.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ")
quantized_model = AutoModelForCausalLM.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ")
print(tokenizer.decode(quantized_model.generate(**tokenizer("gptqmodel is", return_tensors="pt").to(quantized_model.device))[0]))
4 changes: 1 addition & 3 deletions examples/inference/run_with_different_backends.py
@@ -19,10 +19,8 @@
import sys
from argparse import ArgumentParser

from transformers import AutoTokenizer

from gptqmodel import BACKEND, GPTQModel, QuantizeConfig, get_best_device

from transformers import AutoTokenizer

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
pretrained_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
4 changes: 1 addition & 3 deletions examples/quantization/basic_usage.py
@@ -16,10 +16,8 @@

import os

from transformers import AutoTokenizer

from gptqmodel import GPTQModel, QuantizeConfig, get_best_device

from transformers import AutoTokenizer

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

4 changes: 1 addition & 3 deletions examples/quantization/basic_usage_autoround.py
@@ -15,11 +15,9 @@
# limitations under the License.

import torch
from transformers import AutoTokenizer

from gptqmodel import GPTQModel
from gptqmodel.quantization.config import AutoRoundQuantizeConfig # noqa: E402

from transformers import AutoTokenizer

pretrained_model_id = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0" # "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
quantized_model_id = "./autoround/TinyLlama-1.1B-Chat-v1.0-4bit-128g"
4 changes: 1 addition & 3 deletions examples/quantization/basic_usage_wikitext2.py
@@ -16,10 +16,8 @@

import torch
from datasets import load_dataset
from transformers import AutoTokenizer

from gptqmodel import GPTQModel, QuantizeConfig

from transformers import AutoTokenizer

pretrained_model_id = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0" # "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
quantized_model_id = "TinyLlama-1.1B-Chat-v1.0-4bit-128g"
1 change: 0 additions & 1 deletion examples/quantization/transformers_usage.py
@@ -16,7 +16,6 @@

from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig


model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
dataset = ["gptqmodel is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."]
1 change: 0 additions & 1 deletion gptqmodel/models/_const.py
@@ -25,7 +25,6 @@
from ..utils.rocm import IS_ROCM
from ..utils.torch import HAS_CUDA, HAS_MPS, HAS_XPU


CPU = device("cpu")
CUDA = device("cuda")
CUDA_0 = device("cuda:0")
6 changes: 1 addition & 5 deletions gptqmodel/models/auto.py
@@ -18,7 +18,6 @@

import os


if not os.environ.get("PYTORCH_CUDA_ALLOC_CONF", None):
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = 'expandable_segments:True'
print("ENV: Auto setting PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' for memory saving.")
@@ -29,7 +28,6 @@

import sys # noqa: E402


# TODO: waiting for pytorch implementgation of aten ops for MPS
if sys.platform == "darwin":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
@@ -100,7 +98,6 @@
from .definitions.xverse import XverseGPTQ # noqa: E402
from .definitions.yi import YiGPTQ # noqa: E402


# make quants and inference more determinisitc
torch.manual_seed(787)
random.seed(787)
@@ -311,11 +308,10 @@ def eval(
        if task not in EVAL.get_task_enums():
            raise ValueError(f"lm_eval support tasks: {EVAL.get_all_tasks_string()}")

        from gptqmodel.utils.eval import lm_eval
        from lm_eval.utils import make_table
        from transformers import AutoTokenizer

        from gptqmodel.utils.eval import lm_eval

        tokenizer = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code)

        model_name = 'hf' if backend == 'gptqmodel' else backend
41 changes: 18 additions & 23 deletions gptqmodel/models/base.py
@@ -31,39 +31,22 @@
from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase, modeling_utils

from ..nn_modules.hooked_linear import replace_linear_with_hooked_linear
from ..nn_modules.qlinear import BaseQuantLinear
from ..quantization import GPTQ, QuantizeConfig
from ..quantization.config import FORMAT, QUANTIZE_BLACK_LIST, AutoRoundQuantizeConfig
from ..utils.backend import BACKEND
from ..utils.data import collate_data
from ..utils.device import get_cpu_usage_memory, get_gpu_usage_memory
from ..utils.importer import select_quant_linear
from ..utils.logger import setup_logger
from ..utils.model import (
    MODALITY,
    check_to_quantized,
    find_modules,
    get_device,
    get_module,
    get_module_by_name_prefix,
    get_moe_layer_modules,
    move_to,
    nested_move_to,
    pack_model,
)
from ..utils.model import (MODALITY, check_to_quantized, find_modules, get_device, get_module,
                           get_module_by_name_prefix, get_moe_layer_modules, move_to, nested_move_to, pack_model)
from ..utils.progress import ProgressBar
from ..utils.torch import torch_empty_cache
from ._const import CPU, DEFAULT_MAX_SHARD_SIZE, DEVICE, SUPPORTS_MODULE_TYPES, CALIBRATION_DATASET_CONCAT_CHAR
from ._const import CALIBRATION_DATASET_CONCAT_CHAR, CPU, DEFAULT_MAX_SHARD_SIZE, DEVICE, SUPPORTS_MODULE_TYPES
from .loader import ModelLoader
from .writer import (
    QUANT_LOG_DAMP,
    QUANT_LOG_FWD_TIME,
    QUANT_LOG_LAYER,
    QUANT_LOG_LOSS,
    QUANT_LOG_MODULE,
    QUANT_LOG_TIME,
    ModelWriter,
)

from .writer import (QUANT_LOG_DAMP, QUANT_LOG_FWD_TIME, QUANT_LOG_LAYER,
                     QUANT_LOG_LOSS, QUANT_LOG_MODULE, QUANT_LOG_TIME, ModelWriter)

# pytorch 2.6.0 fixes many compilation errors
PYTORCH_MIN_VERFSION_WITH_COMPILE = Version("2.6.0")
@@ -142,6 +125,7 @@ def __init__(
        super().__init__()

        self.model = model
        self.compiled = False # set to True while compile() is triggered successfully
        self.quantized = quantized
        self.load_quantized_model = load_quantized_model
        if tokenizer is not None:
@@ -997,6 +981,7 @@ def compile(self, backend="inductor", mode="max-autotune"):
            return self

        if Version(torch.__version__) < PYTORCH_MIN_VERFSION_WITH_COMPILE:
            self.compiled = False
            logger.warning("To use compile(), you need to have torch version >= 2.5.1, please upgrade it by `pip install torch -U`")
            return self

@@ -1006,12 +991,22 @@

        try:
            self.model = torch.compile(self.model, fullgraph=True, backend=backend, mode=mode)
            self.compiled = True
        except Exception as e:
            logger.info(f"Compiling model again with `fullgraph=False`; `full-graph=True` compile failed: {e}")
            try:
                self.model = torch.compile(self.model, fullgraph=False, backend=backend, mode=mode)
                self.compiled = True
            except Exception as e:
                self.compiled = False
                logger.info(f"Compiling model failed: running model in non-compiled mode. {e}")

        # trigger kernel compilation hooks
        if self.compiled:
            modules = find_modules(self.model, layers=[BaseQuantLinear])
            for name in modules.keys():
                modules[name].compile()

        return self

def serve(self,
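A minimal usage sketch of the new hook from the caller's side (not part of this commit; it assumes the `GPTQModel.load()` entry point and reuses the TinyLlama GPTQ checkpoint referenced in the examples above):

```python
# Hypothetical usage sketch, not code from this commit.
from gptqmodel import GPTQModel

# Model id borrowed from examples/inference/run_transformers.py above.
model = GPTQModel.load("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ")

# compile() wraps the whole model with torch.compile (falling back from
# fullgraph=True to fullgraph=False) and, if that succeeds, now also calls the
# new per-kernel compile() hook on every BaseQuantLinear module it finds.
model = model.compile(backend="inductor", mode="max-autotune")

print(model.compiled)  # True only when torch.compile succeeded (torch >= 2.6.0 per PYTORCH_MIN_VERFSION_WITH_COMPILE)
```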
1 change: 0 additions & 1 deletion gptqmodel/models/definitions/gemma2.py
@@ -18,7 +18,6 @@
from ...utils.logger import setup_logger
from ..base import BaseGPTQModel


logger = setup_logger()

SUPPORT_ERR = "Currently, only vLLM/SGLang with flashinfer enabled can correctly inference a quantized Gemma2-27B model. Pre-quantized model with sample vLLM code: https://huggingface.co/ModelCloud/gemma-2-27b-it-gptq-4bit ."
24 changes: 5 additions & 19 deletions gptqmodel/models/loader.py
@@ -37,27 +37,13 @@
from ..utils.backend import BACKEND
from ..utils.importer import auto_select_device, normalize_device_device_map, select_quant_linear
from ..utils.logger import setup_logger
from ..utils.marlin import (
    _validate_marlin_compatibility,
    _validate_marlin_device_support,
    prepare_model_for_marlin_load,
)
from ..utils.model import (
    auto_dtype,
    convert_gptq_v1_to_v2_format,
    find_modules,
    get_checkpoints,
    get_moe_layer_modules,
    gptqmodel_post_init,
    load_checkpoint_in_model_then_tie_weights,
    make_quant,
    simple_dispatch_model,
    verify_model_hash,
    verify_sharded_model_hashes,
)
from ..utils.marlin import (_validate_marlin_compatibility,
                            _validate_marlin_device_support, prepare_model_for_marlin_load)
from ..utils.model import (auto_dtype, convert_gptq_v1_to_v2_format, find_modules, get_checkpoints,
                           get_moe_layer_modules, gptqmodel_post_init, load_checkpoint_in_model_then_tie_weights,
                           make_quant, simple_dispatch_model, verify_model_hash, verify_sharded_model_hashes)
from ._const import DEVICE, SUPPORTED_MODELS, normalize_device


logger = setup_logger()

ATTN_IMPLEMENTATION = "attn_implementation"
30 changes: 6 additions & 24 deletions gptqmodel/models/writer.py
@@ -34,36 +34,18 @@
from transformers.models.auto.tokenization_auto import get_tokenizer_config
from transformers.utils.generic import ContextManagers

from ..quantization.config import (
    FORMAT,
    META_FIELD_DAMP_AUTO_INCREMENT,
    META_FIELD_DAMP_PERCENT,
    META_FIELD_MSE,
    META_FIELD_QUANTIZER,
    META_FIELD_STATIC_GROUPS,
    META_FIELD_TRUE_SEQUENTIAL,
    META_FIELD_URI,
    META_QUANTIZER_GPTQMODEL,
    META_VALUE_URI,
    MIN_VERSION_WITH_V2,
)
from ..quantization.config import (FORMAT, META_FIELD_DAMP_AUTO_INCREMENT, META_FIELD_DAMP_PERCENT, META_FIELD_MSE,
                                   META_FIELD_QUANTIZER, META_FIELD_STATIC_GROUPS, META_FIELD_TRUE_SEQUENTIAL,
                                   META_FIELD_URI, META_QUANTIZER_GPTQMODEL, META_VALUE_URI, MIN_VERSION_WITH_V2)
from ..utils.backend import BACKEND
from ..utils.logger import setup_logger
from ..utils.model import (
    convert_gptq_v2_to_v1_format,
    copy_py_files,
    find_modules,
    get_model_files_size,
    get_moe_layer_modules,
    get_state_dict_for_save,
    load_checkpoint_in_model_then_tie_weights,
    make_quant,
)
from ..utils.model import (convert_gptq_v2_to_v1_format, copy_py_files, find_modules,
                           get_model_files_size, get_moe_layer_modules, get_state_dict_for_save,
                           load_checkpoint_in_model_then_tie_weights, make_quant)
from ..utils.torch import torch_empty_cache
from ..version import __version__
from ._const import CPU, DEFAULT_MAX_SHARD_SIZE


logger = setup_logger()

QUANT_LOG_LAYER = "layer"
6 changes: 5 additions & 1 deletion gptqmodel/nn_modules/qlinear/__init__.py
@@ -275,10 +275,14 @@ def validate_device(cls, device: DEVICE):
        if device not in cls.SUPPORTS_DEVICES:
            raise NotImplementedError(f"{cls} only supports `{cls.SUPPORTS_DEVICES}`: actual device = `{device}`")

    # override me
    # override me, to perform post-weight load to device init
    def post_init(self):
        pass

    # override me, to perform any torch.compile logic on the kernel pre forward
    def compile(self):
        pass

class PackableQuantLinear(BaseQuantLinear):
    def pack(self, linear, scales, zeros, g_idx=None):
        W = linear.weight.data.clone()
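To illustrate what a kernel can do with the new hook, here is a hypothetical override; the class name and the choice to torch.compile only the kernel's forward are illustrative assumptions, not code from this commit:

```python
# Hypothetical kernel-side override of the new BaseQuantLinear.compile() hook.
import torch

from gptqmodel.nn_modules.qlinear import BaseQuantLinear


class ExampleQuantLinear(BaseQuantLinear):  # illustrative subclass, not a real kernel
    def compile(self):
        # Called by BaseGPTQModel.compile() after the full model compiles.
        # Pre-compiling the kernel's forward here moves the one-time compile
        # cost out of the first inference call; failures leave eager mode intact.
        try:
            self.forward = torch.compile(self.forward)
        except Exception:
            pass
```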
2 changes: 0 additions & 2 deletions gptqmodel/nn_modules/qlinear/bitblas.py
@@ -23,13 +23,11 @@
import numpy as np
import torch
import torch.nn as nn

from gptqmodel.nn_modules.qlinear import PackableQuantLinear

from ...models._const import DEVICE, PLATFORM
from ...utils.logger import setup_logger


logger = setup_logger()

BITBLAS_TARGET = None