diff --git a/examples/fp8/quantizer/README.md b/examples/fp8/quantizer/README.md
index 0b6944f688b4..d0895e97dc34 100644
--- a/examples/fp8/quantizer/README.md
+++ b/examples/fp8/quantizer/README.md
@@ -1,6 +1,6 @@
 ### Quantizer Utilities
-`quantize.py`: NVIDIA Quantization utilities using AMMO, ported from TensorRT-LLM:
-`https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/quantize.py`
+`quantize.py`: NVIDIA Quantization utilities using TensorRT-Model-Optimizer, ported
+from TensorRT-LLM: [`examples/quantization/quantize.py`](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/quantize.py)
 
 ### Prerequisite
 
diff --git a/tests/models/test_modelopt.py b/tests/models/test_modelopt.py
new file mode 100644
index 000000000000..e643b115d0ea
--- /dev/null
+++ b/tests/models/test_modelopt.py
@@ -0,0 +1,79 @@
+# flake8: noqa
+"""Tests Model Optimizer fp8 models against ground truth generation
+Note: these tests will only pass on H100
+"""
+import os
+from typing import List
+
+import pytest
+from transformers import AutoTokenizer
+
+from tests.quantization.utils import is_quant_method_supported
+from vllm import LLM, SamplingParams
+
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
+
+MAX_MODEL_LEN = 1024
+
+MODELS = ["nvidia/Llama-3.1-8B-Instruct-FP8"]
+
+EXPECTED_STRS_MAP = {
+    "nvidia/Llama-3.1-8B-Instruct-FP8": [
+        "You're referring to VLLM, a high-performance Large Language Model (LLM) inference and",
+        'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
+        'The comparison between artificial intelligence (AI) and human intelligence in terms of processing information is a complex and',
+        'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
+        '**The Spark of Imagination**\n\nZeta-5, a sleek and efficient robot, whir',
+        'The COVID-19 pandemic has had a profound impact on global economic structures and business models, leading to',
+        'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
+        'Here are the translations:\n\n**Japanese:** 「早起きは早く獲物をとる'
+    ]
+}
+
+
+# This test compares against golden strings for an exact match, since there
+# is no baseline implementation to compare against. It is unstable w.r.t. the
+# specifics of the fp8 implementation and the hardware being run on.
+# Disabled to prevent it from breaking the build.
+@pytest.mark.skip(
+    reason=
+    "Prevent unstable test based on golden strings from breaking the build.")
+@pytest.mark.skipif(not is_quant_method_supported("fp8"),
+                    reason="fp8 is not supported on this GPU type.")
+@pytest.mark.parametrize("model_name", MODELS)
+def test_models(example_prompts, model_name) -> None:
+    model = LLM(
+        model=model_name,
+        max_model_len=MAX_MODEL_LEN,
+        trust_remote_code=True,
+        enforce_eager=True,
+        quantization="modelopt",
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    formatted_prompts = [
+        tokenizer.apply_chat_template([{
+            "role": "user",
+            "content": prompt
+        }],
+                                      tokenize=False,
+                                      add_generation_prompt=True)
+        for prompt in example_prompts
+    ]
+    params = SamplingParams(max_tokens=20, temperature=0)
+    generations: List[str] = []
+    # Note: these need to be run 1 at a time due to numerical precision,
+    # since the expected strs were generated this way.
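+    # model.generate returns one RequestOutput per prompt; keep the text of
+    # the first (and only) completion for each prompt.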
+    for prompt in formatted_prompts:
+        outputs = model.generate(prompt, params)
+        generations.append(outputs[0].outputs[0].text)
+    del model
+
+    print(model_name, generations)
+    expected_strs = EXPECTED_STRS_MAP[model_name]
+    for i in range(len(example_prompts)):
+        generated_str = generations[i]
+        expected_str = expected_strs[i]
+        assert expected_str == generated_str, (
+            f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}")
diff --git a/vllm/config.py b/vllm/config.py
index 8f5e02e35f28..1236678c821e 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -282,7 +282,7 @@ def _verify_quantization(self) -> None:
         supported_quantization = [*QUANTIZATION_METHODS]
         rocm_supported_quantization = ["awq", "gptq", "fp8"]
         optimized_quantization_methods = [
-            "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
-            "fbgemm_fp8", "compressed_tensors", "compressed-tensors",
-            "experts_int8"
+            "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
+            "awq_marlin", "fbgemm_fp8", "compressed_tensors",
+            "compressed-tensors", "experts_int8"
         ]
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index b997507ea738..cea768469aeb 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -26,7 +26,8 @@
     "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod", "AWQLinearMethod",
     "GPTQMarlinLinearMethod", "Fp8LinearMethod", "MarlinLinearMethod",
     "QQQLinearMethod", "GPTQMarlin24LinearMethod",
-    "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod"
+    "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
+    "ModelOptFp8LinearMethod"
 ]
diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py
index aa5c288962d9..b212ece1d11d 100644
--- a/vllm/model_executor/layers/quantization/__init__.py
+++ b/vllm/model_executor/layers/quantization/__init__.py
@@ -24,6 +24,7 @@
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
+from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config
 from vllm.model_executor.layers.quantization.neuron_quant import (
     NeuronQuantConfig)
 from vllm.model_executor.layers.quantization.qqq import QQQConfig
 from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
@@ -34,6 +35,7 @@
     "tpu_int8": Int8TpuConfig,
     "fp8": Fp8Config,
     "fbgemm_fp8": FBGEMMFp8Config,
+    "modelopt": ModelOptFp8Config,
     # The order of gptq methods is important for config.py iteration over
     # override_quantization_method(..)
"marlin": MarlinConfig, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py new file mode 100644 index 000000000000..dc5f47eb9b0f --- /dev/null +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -0,0 +1,163 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn import Module +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale) +from vllm.model_executor.parameter import (ModelWeightParameter, + PerTensorScaleParameter) + +logger = init_logger(__name__) + +ACTIVATION_SCHEMES = ["static"] + + +class ModelOptFp8Config(QuantizationConfig): + """Config class for ModelOpt FP8.""" + + def __init__( + self, + is_checkpoint_fp8_serialized: bool = False, + ) -> None: + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + if is_checkpoint_fp8_serialized: + logger.warning("Detected ModelOpt fp8 checkpoint. Please note that" + " the format is experimental and could change.") + + @classmethod + def get_name(cls) -> str: + return "modelopt" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 89 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["hf_quant_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config": + quant_config = cls.get_from_keys(config, ["quantization"]) + quant_method = quant_config["quant_algo"] + is_checkpoint_fp8_serialized = ("FP8" in quant_method) + if not is_checkpoint_fp8_serialized: + raise ValueError("ModelOpt currently only supports static FP8" + "quantization in vLLM. Please check the " + "`hf_quant_config.json` file for your model's " + "quant configuration.") + return cls(is_checkpoint_fp8_serialized) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + if isinstance(layer, LinearBase): + return ModelOptFp8LinearMethod(self) + elif isinstance(layer, Attention): + return ModelOptFp8KVCacheMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class ModelOptFp8KVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from FP8 checkpoints. + """ + + def __init__(self, quant_config: ModelOptFp8Config): + super().__init__(quant_config) + + +class ModelOptFp8LinearMethod(LinearMethodBase): + """Linear method for Model Optimizer static quantization. + Supports loading FP8 checkpoints with static weight scale and + activation scale. Future support might be added for dynamic + scales. + + Limitations: + 1. Only support per-tensor quantization due to torch._scaled_mm support. + 2. Only support float8_e4m3fn datatype + Args: quant_config: The ModelOpt quantization config. 
+ """ + + def __init__(self, quant_config: ModelOptFp8Config): + self.quant_config = quant_config + self.cutlass_fp8_supported = cutlass_fp8_supported() + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + weight_dtype = (torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized else + params_dtype) + weight = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=weight_dtype), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) + + if self.quant_config.is_checkpoint_fp8_serialized: + # WEIGHT SCALE + weight_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) + # INPUT SCALE + scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + + scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("input_scale", scale) + + def process_weights_after_loading(self, layer: Module) -> None: + max_w_scale, weight = requantize_with_max_scale( + layer.weight, layer.weight_scale, layer.logical_widths) + layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight_scale = Parameter(max_w_scale, requires_grad=False) + layer.input_scale = Parameter(layer.input_scale.max(), + requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return apply_fp8_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias, + cutlass_fp8_supported=self.cutlass_fp8_supported) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 075451292a8e..5051d45dd115 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -192,6 +192,13 @@ def get_quant_config(model_config: ModelConfig, if model_config.quantization == "bitsandbytes": config["adapter_name_or_path"] = model_name_or_path + elif model_config.quantization == "modelopt": + if config["producer"]["name"] == "modelopt": + return quant_cls.from_config(config) + else: + raise ValueError( + f"Unsupported quantization config" + f" found for {model_config.quantization} in {f}.") return quant_cls.from_config(config)