diff --git a/flake.lock b/flake.lock index 76b4ca2fe38..1706385a155 100644 --- a/flake.lock +++ b/flake.lock @@ -978,16 +978,16 @@ "nixpkgs": "nixpkgs_6" }, "locked": { - "lastModified": 1729531056, - "narHash": "sha256-dW9IOA31+j3VS19WAWAmkJW2YCzeVZGqd6HpIJfODtI=", + "lastModified": 1729761651, + "narHash": "sha256-GYykQ9Fxji2EuXCGcPn0dx8Qx8VQBJTkRdcCytp4A/k=", "owner": "huggingface", "repo": "text-generation-inference-nix", - "rev": "a84a90281a17b15762873845c947e5c78f5a8dd1", + "rev": "f7e3c4fa67d70590ed9ee47feeab645bd9ba81b1", "type": "github" }, "original": { "owner": "huggingface", - "ref": "marlin-kernels-0.3.0", + "ref": "marlin-kernels-0.3.1", "repo": "text-generation-inference-nix", "type": "github" } diff --git a/flake.nix b/flake.nix index 5c05bfae7fb..45441caeec6 100644 --- a/flake.nix +++ b/flake.nix @@ -5,7 +5,7 @@ inputs.nixpkgs.follows = "tgi-nix/nixpkgs"; }; nix-filter.url = "github:numtide/nix-filter"; - tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.0"; + tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.1"; nixpkgs.follows = "tgi-nix/nixpkgs"; flake-utils.url = "github:numtide/flake-utils"; rust-overlay = { diff --git a/nix/server.nix b/nix/server.nix index 7406d563559..4091554691a 100644 --- a/nix/server.nix +++ b/nix/server.nix @@ -8,7 +8,6 @@ eetq, einops, exllamav2, - fbgemm-gpu, flashinfer, flash-attn, flash-attn-layer-norm, @@ -77,7 +76,6 @@ buildPythonPackage { causal-conv1d einops exllamav2 - fbgemm-gpu flashinfer flash-attn flash-attn-layer-norm diff --git a/server/Makefile b/server/Makefile index 18424dd6d7e..ec004640593 100644 --- a/server/Makefile +++ b/server/Makefile @@ -5,7 +5,6 @@ include Makefile-awq include Makefile-eetq include Makefile-selective-scan include Makefile-lorax-punica -include Makefile-fbgemm include Makefile-exllamav2 include Makefile-flashinfer diff --git a/server/Makefile-fbgemm b/server/Makefile-fbgemm deleted file mode 100644 index 3b8061a1fc4..00000000000 --- a/server/Makefile-fbgemm +++ /dev/null @@ -1,15 +0,0 @@ -fbgemm_commit := v0.8.0 - -build-fbgemm: - @if [ ! -d "fbgemm" ]; then \ - git clone https://github.com/pytorch/FBGEMM.git fbgemm; \ - fi - cd fbgemm && git fetch && git checkout $(fbgemm_commit) && \ - git submodule update --init --recursive && \ - cd fbgemm_gpu && \ - pip install -r requirements.txt && \ - CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai build - -install-fbgemm: build-fbgemm - cd fbgemm/fbgemm_gpu && \ - CUDA_ARCH_LIST="8.0;9.0a" NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_90a,code=sm_90a" TORCH_CUDA_ARCH_LIST="8.0;9.0a" python setup.py --package_variant genai install diff --git a/server/poetry.lock b/server/poetry.lock index 1293e883656..e75786c3383 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "accelerate" @@ -1215,12 +1215,12 @@ files = [ [[package]] name = "marlin-kernels" -version = "0.3.0" +version = "0.3.1" description = "Marlin quantization kernels" optional = true python-versions = ">=3.7" files = [ - {file = "marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:a2086b9e98d22071f52c5b4b4b98b1b4a988565258905173fa74c5a9eddd1a0a"}, + {file = "marlin_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", hash = "sha256:705c89ed54977099a40b37dc0c796964649024f1a8819a1832118cd7b146efe1"}, ] [package.dependencies] @@ -1228,16 +1228,16 @@ torch = "*" [package.source] type = "url" -url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl" [[package]] name = "marlin-kernels" -version = "0.3.0" +version = "0.3.1" description = "Marlin quantization kernels" optional = true python-versions = ">=3.7" files = [ - {file = "marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:f39a6946d8247629446ec170832d832c7038c363f1d8803211fe67249c2d804d"}, + {file = "marlin_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", hash = "sha256:e1f3d123eca643149d0a4f6b81c4405d78abb3a694a78fccc8670a25b3404406"}, ] [package.dependencies] @@ -1245,16 +1245,16 @@ torch = "*" [package.source] type = "url" -url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl" [[package]] name = "marlin-kernels" -version = "0.3.0" +version = "0.3.1" description = "Marlin quantization kernels" optional = true python-versions = ">=3.7" files = [ - {file = "marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:07fd869d5289777fa866107dae676523e18b1f6ba4afce79946ddc58a6870169"}, + {file = "marlin_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", hash = "sha256:9d68367fd5e1caf2edc90b77ad5d074b11586012265a3147ecca1f1171ae22f8"}, ] [package.dependencies] @@ -1262,16 +1262,16 @@ torch = "*" [package.source] type = "url" -url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl" [[package]] name = "marlin-kernels" -version = "0.3.0" +version = "0.3.1" description = "Marlin quantization kernels" optional = true python-versions = ">=3.7" files = [ - {file = "marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:0dedaa418225d490a5f1d8f85dbc75e439a8c43a8870e4ef32945bf61672d7dc"}, + {file = "marlin_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", hash = "sha256:d962277c5f7642972e298650913dd0546b9f735b706dc88bb34955b3cac7f330"}, ] [package.dependencies] @@ -1279,7 +1279,7 @@ torch = "*" [package.source] type = "url" -url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl" +url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl" [[package]] name = "mdurl" @@ -1770,6 +1770,7 @@ description = "Nvidia JIT LTO Library" optional = true python-versions = ">=3" files = [ + {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"}, {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"}, {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"}, ] @@ -3966,4 +3967,4 @@ torch = ["torch"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "500fa44255e4a6c89a16314a931548447afe1ba71ea341a73cad6670e46ddac7" +content-hash = "b39033e573f50a0f046787aebf1702d86673aad0b2fcee818404fcea7f644b81" diff --git a/server/pyproject.toml b/server/pyproject.toml index d08d0b8f488..5c414d6e0ec 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -41,10 +41,10 @@ py-cpuinfo = "^9.0.0" numpy = "^1.26" marlin-kernels = [ - { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, - { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true }, - { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, - { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true }, + { url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true }, ] moe-kernels = [ { url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true }, diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py index a58c7f7b223..216881739e9 100644 --- a/server/text_generation_server/layers/fp8.py +++ b/server/text_generation_server/layers/fp8.py @@ -1,7 +1,8 @@ -import torch - from dataclasses import dataclass -from typing import Optional, Tuple, Union, List +import os +from typing import Optional, Tuple, Type, Union, List + +import torch from loguru import logger from text_generation_server.utils.import_utils import SYSTEM @@ -11,20 +12,7 @@ UnquantizedWeight, Weights, ) -from text_generation_server.utils.log import log_master, log_once -import importlib.util - - -FBGEMM_MM_AVAILABLE = False -FBGEMM_DYN_AVAILABLE = False - - -def is_fbgemm_gpu_available(): - try: - return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None - except ModuleNotFoundError: - return False - +from text_generation_server.utils.log import log_once try: import marlin_kernels @@ -32,23 +20,26 @@ def is_fbgemm_gpu_available(): marlin_kernels = None -if is_fbgemm_gpu_available(): - if SYSTEM == "cuda": - major, _ = torch.cuda.get_device_capability() - FBGEMM_MM_AVAILABLE = major == 9 - FBGEMM_DYN_AVAILABLE = major >= 8 +if SYSTEM == "cuda" and marlin_kernels is not None: + major, minor = torch.cuda.get_device_capability() + CUTLASS_FP8_AVAILABLE = marlin_kernels.cutlass_scaled_mm_supports_fp8( + major * 10 + minor + ) else: - log_master(logger.warning, "FBGEMM fp8 kernels are not installed.") + CUTLASS_FP8_AVAILABLE = False -def get_fp8_linear() -> torch.nn.Module: +def get_fp8_linear() -> Type[torch.nn.Module]: """ Return an FP8 linear `Module` that is compatible with the current system. """ if SYSTEM == "cuda": + major, _ = torch.cuda.get_device_capability() - if major == 8: + if major == 8 and os.getenv("USE_CUTLASS_W8A8", "0") != "1": + # NOTE: Capability 8.9 is supported by cutlass kernels, but FP8-Marlin + # gives better decoding throughput on L4 and L40. from text_generation_server.layers.marlin import GPTQMarlinFP8Linear return GPTQMarlinFP8Linear @@ -94,12 +85,6 @@ def fp8_quantize( argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can be used without modification). """ - if FBGEMM_DYN_AVAILABLE and not scalar and not scale: - qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row( - weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype - ) - return qweight, scale - if marlin_kernels is not None: shape = weight.shape qweight, scale = marlin_kernels.scaled_fp8_quant( @@ -107,11 +92,12 @@ def fp8_quantize( dtype=qdtype, scale=scale, scale_ub=scale_upper_bound, + # TODO: don't do this when we have to use the Torch kernel. + use_per_token_if_dynamic=not scalar, ) return qweight.reshape(shape), scale - # weight, scale = quant_weights(weight, torch.int8, False) finfo = torch.finfo(qdtype) if scale is None: @@ -327,8 +313,8 @@ def __init__( scale_upper_bound: Optional[float] = None, ) -> None: super().__init__() - if FBGEMM_MM_AVAILABLE: - log_once(logger.info, "Using FBGEMM fp8 optimized kernels") + if CUTLASS_FP8_AVAILABLE: + log_once(logger.info, "Using cutlass w8a8 kernels") if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn: qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz( weight=qweight, weight_scale=scale @@ -339,13 +325,9 @@ def __init__( self.scale = scale.float() self.input_scale = input_scale.float() if input_scale is not None else None - if FBGEMM_MM_AVAILABLE: - self.scale_upper_bound = ( - torch.tensor( - [scale_upper_bound], dtype=torch.float32, device=qweight.device - ) - if scale_upper_bound is not None - else None + if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None: + self.scale_upper_bound = torch.tensor( + scale_upper_bound, dtype=torch.float32, device=qweight.device ) else: self.scale_upper_bound = scale_upper_bound @@ -354,7 +336,7 @@ def __init__( @classmethod def from_unquant(cls, weight, bias, dtype): - qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE) + qweight, scale = fp8_quantize(weight, scalar=not CUTLASS_FP8_AVAILABLE) return cls( qweight=qweight, scale=scale, @@ -376,9 +358,6 @@ def from_fp8( input_scale = kwargs.get("input_scale", None) scale_upper_bound = kwargs.get("scale_upper_bound", None) - if FBGEMM_DYN_AVAILABLE: - # fbgemm needs float32 scales. - scale = scale.float() return cls( qweight=weight, scale=scale, @@ -397,20 +376,14 @@ def get_shared_device_identity(cls, device): return cls._device_identity_cache[device] def forward(self, input: torch.Tensor) -> torch.Tensor: - if FBGEMM_MM_AVAILABLE: + if CUTLASS_FP8_AVAILABLE: + # cutlass FP8 supports per-token scales, so get non-scalar scales. qinput, scale = fp8_quantize( - input, scale_upper_bound=self.scale_upper_bound + input, scale_upper_bound=self.scale_upper_bound, scalar=False ) - - y = torch.ops.fbgemm.f8f8bf16_rowwise( - qinput, - self.qweight, - scale, - self.scale, - use_fast_accum=True, - bias=self.bias, + return marlin_kernels.cutlass_scaled_mm( + qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias ) - return y.to(self.dtype) qinput, scale = fp8_quantize( input, diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index d30154083f5..f4fa431c30e 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -410,12 +410,6 @@ def get_model( else: # These quantizers only work with float16 params. dtype = torch.float16 - elif quantize == "fp8": - from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE - - if FBGEMM_DYN_AVAILABLE: - # fbgemm kernels are fp8xfp8->bf16 - dtype = torch.bfloat16 else: # Keep it as default for now and let # every model resolve their own default dtype.