Add fp8 support for MoE models #2928

Merged 2 commits on Jan 29, 2025
@@ -7,7 +7,7 @@
Fp8Weight,
_load_scalar_or_matrix_scale,
requantize_with_max_scale,
-    normalize_e4m3fn_to_e4m3fnuz,
+    normalize_e4m3fn_to_native_float8,
)
from text_generation_server.utils.weights import Weights, WeightsLoader
from text_generation_server.utils.import_utils import SYSTEM
@@ -148,7 +148,7 @@ def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: in
)

if self.load_weight_scale and SYSTEM == "rocm":
-            w, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+            w, weight_scale, input_scale = normalize_e4m3fn_to_native_float8(
w, weight_scale, input_scale
)

12 changes: 6 additions & 6 deletions server/text_generation_server/layers/fp8.py
@@ -58,12 +58,12 @@ def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
return Fp8Linear


-def normalize_e4m3fn_to_e4m3fnuz(
+def normalize_e4m3fn_to_native_float8(
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-    if weight.dtype == torch.float8_e4m3fn:
+    if weight.dtype == torch.float8_e4m3fn and SYSTEM == "rocm":
Member: The function would now not normalize on SYSTEM != "rocm" even if the data type is float8_e4m3fn. I think either the function should be renamed to normalize_e4m3fn_to_native_float8, or this condition should not be there (and the conversion should happen regardless of SYSTEM).

Collaborator (Author): Done, renamed the function.
# The bits pattern 10000000(-128) represents zero in e4m3fn
# but NaN in e4m3fnuz. So here we set it to 0.
# https://onnx.ai/onnx/technical/float8.html
@@ -162,7 +162,7 @@ def fp8_quantize(
qweight = qweight.to(qdtype)

if SYSTEM == "rocm":
-        qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)
+        qweight, scale, _ = normalize_e4m3fn_to_native_float8(qweight, scale)

return qweight, scale

@@ -285,7 +285,7 @@ def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: in
)

if SYSTEM == "rocm":
-        w, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+        w, scale, input_scale = normalize_e4m3fn_to_native_float8(
w, scale, input_scale
)

@@ -380,8 +380,8 @@ def __init__(
if CUTLASS_FP8_AVAILABLE:
log_once(logger.info, "Using cutlass w8a8 kernels")
if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
-            qweight, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
-                weight=qweight, weight_scale=scale
+            qweight, scale, input_scale = normalize_e4m3fn_to_native_float8(
+                weight=qweight, weight_scale=scale, input_scale=input_scale
)

self.dtype = dtype
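
For context on the review thread above: on ROCm the native float8 format is e4m3fnuz rather than e4m3fn, so the renamed helper re-encodes e4m3fn weights and rescales them. A minimal sketch of that conversion, assuming the usual e4m3fn-to-e4m3fnuz handling (the body below is illustrative, not copied from this PR):

# Hedged sketch of the ROCm normalization path; not the exact upstream code.
from typing import Optional, Tuple

import torch


def normalize_e4m3fn_to_native_float8_sketch(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    # The bit pattern 10000000 (-128 as int8) encodes zero in e4m3fn but NaN
    # in e4m3fnuz, so those elements are cleared before reinterpreting.
    as_int8 = weight.view(torch.int8)
    as_int8[as_int8 == -128] = 0
    weight = as_int8.view(torch.float8_e4m3fnuz)

    # For the same bit pattern, an e4m3fnuz value is half of the e4m3fn value,
    # so scales are doubled to keep dequantized values unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale

On CUDA (SYSTEM != "rocm"), e4m3fn is already the native float8 format, which is why the renamed function can safely skip the conversion there.
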
15 changes: 10 additions & 5 deletions server/text_generation_server/layers/moe/__init__.py
@@ -16,6 +16,7 @@
can_use_marlin_moe_gemm,
)
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
+from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
@@ -218,12 +219,16 @@ def __init__(
down_proj_name: str = "down_proj",
):
super().__init__()

-        if (
-            isinstance(weights.loader, DefaultWeightsLoader)
-            and isinstance(weights.loader.weight_class, UnquantizedWeight)
-        ) or isinstance(weights.loader, HybridFP8UnquantLoader):
+        if isinstance(weights.loader, DefaultWeightsLoader) and isinstance(
+            weights.loader.weight_class, UnquantizedWeight
+        ):
            cls = UnquantizedSparseMoELayer
+        elif isinstance(weights.loader, HybridFP8UnquantLoader):
+            cls = (
+                FP8SparseMoELayer
+                if weights.loader.to_fp8
+                else UnquantizedSparseMoELayer
+            )
elif isinstance(
weights.loader, GPTQMarlinWeightsLoader
) and can_use_marlin_moe_gemm(
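
In effect, HybridFP8UnquantLoader no longer always falls back to the unquantized MoE layer. A hedged sketch of the resulting selection (hypothetical _select_sparse_moe_layer helper; import paths assumed from the surrounding diff):

from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight


def _select_sparse_moe_layer(loader):
    if isinstance(loader, DefaultWeightsLoader) and isinstance(
        loader.weight_class, UnquantizedWeight
    ):
        return UnquantizedSparseMoELayer
    if isinstance(loader, HybridFP8UnquantLoader):
        # An fp8-capable loader only routes MoE weights through the FP8 layer
        # when it is actually quantizing to fp8.
        return FP8SparseMoELayer if loader.to_fp8 else UnquantizedSparseMoELayer
    raise NotImplementedError("other loaders (e.g. GPTQ-Marlin) are handled separately")
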
162 changes: 162 additions & 0 deletions server/text_generation_server/layers/moe/fp8.py
@@ -0,0 +1,162 @@
from typing import Optional, Tuple

import torch
import torch.nn as nn

from text_generation_server.utils.weights import Weights
from text_generation_server.layers.fp8 import (
Fp8Weight,
fp8_quantize,
quant_dtype,
normalize_e4m3fn_to_native_float8,
)
from moe_kernels.fused_moe import fused_moe


class FP8SparseMoELayer(nn.Module):
def __init__(
self,
*,
n_expert_group: Optional[int],
n_experts: int,
prefix: str,
renormalize: bool,
topk: int,
topk_group: Optional[int],
weights: Weights,
gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
):
super().__init__()

assert (n_expert_group is None) == (
topk_group is None
), "n_expert_group and topk_group must both be None or have some value"

self.n_expert_group = n_expert_group
self.topk = topk
self.topk_group = topk_group
self.renormalize = renormalize

(
self.gate_up_proj,
self.gate_up_proj_weight_scale,
self.gate_up_proj_input_scale,
) = _load_expert_multi_weights_col(
prefix=prefix,
n_experts=n_experts,
gate_proj_name=gate_proj_name,
up_proj_name=up_proj_name,
weights=weights,
)

self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (
_load_expert_weights_row(
prefix=prefix,
n_experts=n_experts,
name=down_proj_name,
weights=weights,
)
)

def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
return fused_moe(
x,
w1=self.gate_up_proj,
w2=self.down_proj,
gating_output=gating_output,
topk=self.topk,
renormalize=self.renormalize,
inplace=True,
use_grouped_topk=self.n_expert_group is not None,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
use_fp8_w8a8=True,
w1_scale=self.gate_up_proj_weight_scale,
w2_scale=self.down_proj_weight_scale,
a1_scale=self.gate_up_proj_input_scale,
a2_scale=self.down_proj_input_scale,
)


def _load_expert_weights(
get_weight_fn,
*,
prefix: str,
n_experts: int,
name: str,
weights: Weights,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
all_weight = None
all_weight_scales = None
max_input_scale = None

for i in range(n_experts):
weight = get_weight_fn(prefix, i, name, weights)

assert isinstance(weight, Fp8Weight)

if all_weight is None:
all_weight = torch.empty(
(n_experts,) + weight.weight.shape,
dtype=quant_dtype,
device=weight.weight.device,
)
if all_weight_scales is None:
all_weight_scales = torch.empty(
(n_experts,),
dtype=torch.float32,
device=weight.weight.device,
)

if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
all_weight[i], all_weight_scales[i], current_input_scale = (
normalize_e4m3fn_to_native_float8(
weight.weight, weight.weight_scale, weight.input_scale
)
)
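            # fused_moe takes one activation scale per projection, so keep the
            # most conservative (largest) input scale seen across the experts.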
if current_input_scale is not None:
if max_input_scale is None or current_input_scale > max_input_scale:
max_input_scale = current_input_scale
else:
all_weight[i], all_weight_scales[i] = fp8_quantize(
weight.weight, scalar=True
)

assert all_weight is not None

return all_weight, all_weight_scales, max_input_scale


def _load_expert_multi_weights_col(
*,
prefix: str,
n_experts: int,
gate_proj_name: str,
up_proj_name: str,
weights: Weights,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
def get_weight_fn(prefix, i, name, weights):
return weights.get_multi_weights_col(
[f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
)

return _load_expert_weights(
get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights
)


def _load_expert_weights_row(
*,
prefix: str,
n_experts: int,
name: str,
weights: Weights,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
def get_weight_fn(prefix, i, name, weights):
return weights.get_weights_row(f"{prefix}.{i}.{name}")

return _load_expert_weights(
get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights
)
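
As a rough guide to what the forward pass above asks fused_moe to compute, here is an unoptimized pure-PyTorch reference, assuming the gate and up projections are concatenated along the first weight dimension and a SiLU-gated MLP per expert; reference_fp8_moe is an illustrative helper, not part of this PR, and it ignores activation (a8) quantization:

# Hedged reference of the fused FP8 MoE math; not the fused kernel itself.
import torch
import torch.nn.functional as F


def reference_fp8_moe(x, gate_up, gate_up_scale, down, down_scale,
                      gating_output, topk, renormalize=True):
    # Dequantize the stacked expert weights with their per-expert scalar scales.
    # Assumed shapes: gate_up (n_experts, 2 * intermediate, hidden),
    # down (n_experts, hidden, intermediate).
    w1 = gate_up.to(torch.float32) * gate_up_scale.view(-1, 1, 1)
    w2 = down.to(torch.float32) * down_scale.view(-1, 1, 1)

    # Top-k routing over the gating logits, optionally renormalized.
    probs = torch.softmax(gating_output.float(), dim=-1)
    routing_weights, experts = probs.topk(topk, dim=-1)
    if renormalize:
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

    out = torch.zeros(x.shape[0], w2.shape[1])
    for t in range(x.shape[0]):
        for rw, e in zip(routing_weights[t], experts[t]):
            gate, up = (x[t].float() @ w1[e].T).chunk(2)
            out[t] += rw * ((F.silu(gate) * up) @ w2[e].T)
    return out
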
12 changes: 1 addition & 11 deletions server/text_generation_server/layers/moe/unquantized.py
Expand Up @@ -58,17 +58,7 @@ def __init__(
)

def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
if SYSTEM == "rocm":
return fused_moe(
x,
self.gate_up_proj,
self.down_proj,
gating_output,
self.topk,
renormalize=self.renormalize,
inplace=True,
)
elif SYSTEM == "ipex":
if SYSTEM == "ipex":
return self.ipex_fused_moe(
hidden_states=x,
router_logits=gating_output,