Switch from fbgemm-gpu w8a8 scaled matmul to vLLM/marlin-kernels
Performance and accuracy of these kernels are on par with the fbgemm-gpu kernels (tested with Llama 70B and 405B). This change removes a dependency and resolves some stability issues we have been seeing.
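At its core, the commit swaps the scaled-matmul call inside the FP8 linear layer's `forward` (see the `server/text_generation_server/layers/fp8.py` diff below). Condensed from that diff, old call versus new call:

```python
# Old: fbgemm-gpu row-wise FP8 GEMM, which always produces bfloat16
# (the removed comment below notes "fbgemm kernels are fp8xfp8->bf16").
y = torch.ops.fbgemm.f8f8bf16_rowwise(
    qinput, self.qweight, scale, self.scale,
    use_fast_accum=True, bias=self.bias,
)
return y.to(self.dtype)

# New: cutlass w8a8 scaled matmul from marlin-kernels; the output dtype
# is passed explicitly, so no cast back is needed.
return marlin_kernels.cutlass_scaled_mm(
    qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
)
```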
danieldk committed Oct 24, 2024
1 parent eab07f7 commit 197d45e
Showing 9 changed files with 53 additions and 103 deletions.
8 changes: 4 additions & 4 deletions flake.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion flake.nix
@@ -5,7 +5,7 @@
inputs.nixpkgs.follows = "tgi-nix/nixpkgs";
};
nix-filter.url = "github:numtide/nix-filter";
tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.0";
tgi-nix.url = "github:huggingface/text-generation-inference-nix/marlin-kernels-0.3.1";
nixpkgs.follows = "tgi-nix/nixpkgs";
flake-utils.url = "github:numtide/flake-utils";
rust-overlay = {
2 changes: 0 additions & 2 deletions nix/server.nix
@@ -8,7 +8,6 @@
eetq,
einops,
exllamav2,
fbgemm-gpu,
flashinfer,
flash-attn,
flash-attn-layer-norm,
@@ -77,7 +76,6 @@ buildPythonPackage {
causal-conv1d
einops
exllamav2
fbgemm-gpu
flashinfer
flash-attn
flash-attn-layer-norm
1 change: 0 additions & 1 deletion server/Makefile
@@ -5,7 +5,6 @@ include Makefile-awq
include Makefile-eetq
include Makefile-selective-scan
include Makefile-lorax-punica
include Makefile-fbgemm
include Makefile-exllamav2
include Makefile-flashinfer

15 changes: 0 additions & 15 deletions server/Makefile-fbgemm

This file was deleted.

29 changes: 15 additions & 14 deletions server/poetry.lock

Some generated files are not rendered by default.

8 changes: 4 additions & 4 deletions server/pyproject.toml
@@ -41,10 +41,10 @@ py-cpuinfo = "^9.0.0"
numpy = "^1.26"

marlin-kernels = [
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.0/marlin_kernels-0.3.0+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp310-cp310-linux_x86_64.whl", python = "~3.10", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl", python = "~3.11", optional = true },
{ url = "https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp312-cp312-linux_x86_64.whl", python = "~3.12", optional = true },
]
moe-kernels = [
{ url = "https://github.com/danieldk/moe-kernels/releases/download/v0.6.0/moe_kernels-0.6.0+cu123torch2.4-cp39-cp39-linux_x86_64.whl", python = "~3.9", optional = true },
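For reference, these wheels can also be installed directly with pip outside of Poetry, e.g. `pip install https://github.com/danieldk/marlin-kernels/releases/download/v0.3.1/marlin_kernels-0.3.1+cu123torch2.4-cp311-cp311-linux_x86_64.whl` for Python 3.11; the wheel tags above assume CUDA 12.3 and Torch 2.4.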
85 changes: 29 additions & 56 deletions server/text_generation_server/layers/fp8.py
@@ -1,7 +1,8 @@
import torch

from dataclasses import dataclass
from typing import Optional, Tuple, Union, List
import os
from typing import Optional, Tuple, Type, Union, List

import torch
from loguru import logger

from text_generation_server.utils.import_utils import SYSTEM
@@ -11,44 +12,34 @@
UnquantizedWeight,
Weights,
)
from text_generation_server.utils.log import log_master, log_once
import importlib.util


FBGEMM_MM_AVAILABLE = False
FBGEMM_DYN_AVAILABLE = False


def is_fbgemm_gpu_available():
try:
return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None
except ModuleNotFoundError:
return False

from text_generation_server.utils.log import log_once

try:
import marlin_kernels
except ImportError:
marlin_kernels = None


if is_fbgemm_gpu_available():
if SYSTEM == "cuda":
major, _ = torch.cuda.get_device_capability()
FBGEMM_MM_AVAILABLE = major == 9
FBGEMM_DYN_AVAILABLE = major >= 8
if SYSTEM == "cuda" and marlin_kernels is not None:
major, minor = torch.cuda.get_device_capability()
CUTLASS_FP8_AVAILABLE = marlin_kernels.cutlass_scaled_mm_supports_fp8(
major * 10 + minor
)
else:
log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")
CUTLASS_FP8_AVAILABLE = False


def get_fp8_linear() -> torch.nn.Module:
def get_fp8_linear() -> Type[torch.nn.Module]:
"""
Return an FP8 linear `Module` that is compatible with the current system.
"""

if SYSTEM == "cuda":

major, _ = torch.cuda.get_device_capability()
if major == 8:
if major == 8 and os.getenv("USE_CUTLASS_W8A8", "0") != "1":
# NOTE: Capability 8.9 is supported by cutlass kernels, but FP8-Marlin
# gives better decoding throughput on L4 and L40.
from text_generation_server.layers.marlin import GPTQMarlinFP8Linear

return GPTQMarlinFP8Linear
@@ -94,24 +85,19 @@ def fp8_quantize(
argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
be used without modification).
"""
if FBGEMM_DYN_AVAILABLE and not scalar and not scale:
qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
)
return qweight, scale

if marlin_kernels is not None:
shape = weight.shape
qweight, scale = marlin_kernels.scaled_fp8_quant(
weight.reshape(-1, shape[-1]),
dtype=qdtype,
scale=scale,
scale_ub=scale_upper_bound,
# TODO: don't do this when we have to use the Torch kernel.
use_per_token_if_dynamic=not scalar,
)

return qweight.reshape(shape), scale

# weight, scale = quant_weights(weight, torch.int8, False)
finfo = torch.finfo(qdtype)

if scale is None:
@@ -327,8 +313,8 @@ def __init__(
scale_upper_bound: Optional[float] = None,
) -> None:
super().__init__()
if FBGEMM_MM_AVAILABLE:
log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
if CUTLASS_FP8_AVAILABLE:
log_once(logger.info, "Using cutlass w8a8 kernels")
if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=qweight, weight_scale=scale
@@ -339,13 +325,9 @@ def __init__(
self.scale = scale.float()
self.input_scale = input_scale.float() if input_scale is not None else None

if FBGEMM_MM_AVAILABLE:
self.scale_upper_bound = (
torch.tensor(
[scale_upper_bound], dtype=torch.float32, device=qweight.device
)
if scale_upper_bound is not None
else None
if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None:
self.scale_upper_bound = torch.tensor(
scale_upper_bound, dtype=torch.float32, device=qweight.device
)
else:
self.scale_upper_bound = scale_upper_bound
@@ -354,7 +336,7 @@ def __init__(

@classmethod
def from_unquant(cls, weight, bias, dtype):
qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
qweight, scale = fp8_quantize(weight, scalar=not CUTLASS_FP8_AVAILABLE)
return cls(
qweight=qweight,
scale=scale,
@@ -376,9 +358,6 @@ def from_fp8(
input_scale = kwargs.get("input_scale", None)
scale_upper_bound = kwargs.get("scale_upper_bound", None)

if FBGEMM_DYN_AVAILABLE:
# fbgemm needs float32 scales.
scale = scale.float()
return cls(
qweight=weight,
scale=scale,
@@ -397,20 +376,14 @@ def get_shared_device_identity(cls, device):
return cls._device_identity_cache[device]

def forward(self, input: torch.Tensor) -> torch.Tensor:
if FBGEMM_MM_AVAILABLE:
if CUTLASS_FP8_AVAILABLE:
# cutlass FP8 supports per-token scales, so get non-scalar scales.
qinput, scale = fp8_quantize(
input, scale_upper_bound=self.scale_upper_bound
input, scale_upper_bound=self.scale_upper_bound, scalar=False
)

y = torch.ops.fbgemm.f8f8bf16_rowwise(
qinput,
self.qweight,
scale,
self.scale,
use_fast_accum=True,
bias=self.bias,
return marlin_kernels.cutlass_scaled_mm(
qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
)
return y.to(self.dtype)

qinput, scale = fp8_quantize(
input,
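Taken together, the new path is: check once at import time whether the cutlass FP8 kernels support the GPU, quantize the weight with per-output-channel scales, quantize the activations per token, and run the scaled matmul. A minimal standalone sketch of that flow — not taken from the repository, and assuming a CUDA GPU that passes the new cutlass capability check and marlin-kernels 0.3.1 installed:

```python
import marlin_kernels
import torch

from text_generation_server.layers import fp8

# The capability check (cutlass_scaled_mm_supports_fp8(major * 10 + minor))
# runs once at import time and sets CUTLASS_FP8_AVAILABLE.
assert fp8.CUTLASS_FP8_AVAILABLE

dtype = torch.bfloat16
weight = torch.randn(4096, 4096, dtype=dtype, device="cuda")  # [out_features, in_features]
x = torch.randn(8, 4096, dtype=dtype, device="cuda")          # [tokens, in_features]

# Per-output-channel weight scales and per-token activation scales,
# mirroring from_unquant / forward when the cutlass kernels are available.
qweight, w_scale = fp8.fp8_quantize(weight, scalar=False)
qinput, x_scale = fp8.fp8_quantize(x, scalar=False)

# fp8 x fp8 matmul with both scales applied; the output dtype is explicit.
y = marlin_kernels.cutlass_scaled_mm(qinput, qweight.t(), x_scale, w_scale, dtype, None)
assert y.shape == (8, 4096) and y.dtype == dtype
```

Note that on compute capability 8.x GPUs, `get_fp8_linear` still prefers FP8-Marlin for its better decoding throughput unless the new `USE_CUTLASS_W8A8=1` environment variable is set (see the diff above).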
6 changes: 0 additions & 6 deletions server/text_generation_server/models/__init__.py
@@ -410,12 +410,6 @@ def get_model(
else:
# These quantizers only work with float16 params.
dtype = torch.float16
elif quantize == "fp8":
from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE

if FBGEMM_DYN_AVAILABLE:
# fbgemm kernels are fp8xfp8->bf16
dtype = torch.bfloat16
else:
# Keep it as default for now and let
# every model resolve their own default dtype.
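The removed `fp8` branch forced `torch.bfloat16` because the fbgemm kernels only produce bf16 output (fp8 x fp8 -> bf16); `cutlass_scaled_mm` takes the output dtype as an argument (`input.dtype` in `forward` above), so fp8 models now fall through to the default per-model dtype resolution like every other quantizer.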
