diff --git a/server/text_generation_server/layers/compressed_tensors/w8an_fp.py b/server/text_generation_server/layers/compressed_tensors/w8an_fp.py
index ebcc06d6811..42c5e633d9f 100644
--- a/server/text_generation_server/layers/compressed_tensors/w8an_fp.py
+++ b/server/text_generation_server/layers/compressed_tensors/w8an_fp.py
@@ -7,7 +7,7 @@
     Fp8Weight,
     _load_scalar_or_matrix_scale,
     requantize_with_max_scale,
-    normalize_e4m3fn_to_e4m3fnuz,
+    normalize_e4m3fn_to_native_float8,
 )
 from text_generation_server.utils.weights import Weights, WeightsLoader
 from text_generation_server.utils.import_utils import SYSTEM
@@ -148,7 +148,7 @@ def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: in
             )
 
         if self.load_weight_scale and SYSTEM == "rocm":
-            w, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+            w, weight_scale, input_scale = normalize_e4m3fn_to_native_float8(
                 w, weight_scale, input_scale
             )
 
diff --git a/server/text_generation_server/layers/fp8.py b/server/text_generation_server/layers/fp8.py
index 4e83ec9d0ea..67b33c980f3 100644
--- a/server/text_generation_server/layers/fp8.py
+++ b/server/text_generation_server/layers/fp8.py
@@ -58,12 +58,12 @@ def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
     return Fp8Linear
 
 
-def normalize_e4m3fn_to_e4m3fnuz(
+def normalize_e4m3fn_to_native_float8(
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-    if weight.dtype == torch.float8_e4m3fn:
+    if weight.dtype == torch.float8_e4m3fn and SYSTEM == "rocm":
         # The bits pattern 10000000(-128) represents zero in e4m3fn
         # but NaN in e4m3fnuz. So here we set it to 0.
         # https://onnx.ai/onnx/technical/float8.html
@@ -162,7 +162,7 @@ def fp8_quantize(
     qweight = qweight.to(qdtype)
 
     if SYSTEM == "rocm":
-        qweight, scale, _ = normalize_e4m3fn_to_e4m3fnuz(qweight, scale)
+        qweight, scale, _ = normalize_e4m3fn_to_native_float8(qweight, scale)
 
     return qweight, scale
 
@@ -285,7 +285,7 @@ def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: in
             )
 
         if SYSTEM == "rocm":
-            w, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+            w, scale, input_scale = normalize_e4m3fn_to_native_float8(
                 w, scale, input_scale
             )
 
@@ -380,8 +380,8 @@ def __init__(
         if CUTLASS_FP8_AVAILABLE:
             log_once(logger.info, "Using cutlass w8a8 kernels")
         if SYSTEM == "rocm" and qweight.dtype == torch.float8_e4m3fn:
-            qweight, scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
-                weight=qweight, weight_scale=scale
+            qweight, scale, input_scale = normalize_e4m3fn_to_native_float8(
+                weight=qweight, weight_scale=scale, input_scale=input_scale
             )
 
         self.dtype = dtype
diff --git a/server/text_generation_server/layers/moe/__init__.py b/server/text_generation_server/layers/moe/__init__.py
index be40d78a8a1..3b227e96ff1 100644
--- a/server/text_generation_server/layers/moe/__init__.py
+++ b/server/text_generation_server/layers/moe/__init__.py
@@ -16,6 +16,7 @@
     can_use_marlin_moe_gemm,
 )
 from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
+from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import (
@@ -218,12 +219,16 @@ def __init__(
         down_proj_name: str = "down_proj",
     ):
         super().__init__()
-
-        if (
-            isinstance(weights.loader, DefaultWeightsLoader)
-            and isinstance(weights.loader.weight_class, UnquantizedWeight)
-        ) or isinstance(weights.loader, HybridFP8UnquantLoader):
+        if isinstance(weights.loader, DefaultWeightsLoader) and isinstance(
+            weights.loader.weight_class, UnquantizedWeight
+        ):
             cls = UnquantizedSparseMoELayer
+        elif isinstance(weights.loader, HybridFP8UnquantLoader):
+            cls = (
+                FP8SparseMoELayer
+                if weights.loader.to_fp8
+                else UnquantizedSparseMoELayer
+            )
         elif isinstance(
             weights.loader, GPTQMarlinWeightsLoader
         ) and can_use_marlin_moe_gemm(
diff --git a/server/text_generation_server/layers/moe/fp8.py b/server/text_generation_server/layers/moe/fp8.py
new file mode 100644
index 00000000000..7ccddb5b029
--- /dev/null
+++ b/server/text_generation_server/layers/moe/fp8.py
@@ -0,0 +1,162 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from text_generation_server.utils.weights import Weights
+from text_generation_server.layers.fp8 import (
+    Fp8Weight,
+    fp8_quantize,
+    quant_dtype,
+    normalize_e4m3fn_to_native_float8,
+)
+from moe_kernels.fused_moe import fused_moe
+
+
+class FP8SparseMoELayer(nn.Module):
+    def __init__(
+        self,
+        *,
+        n_expert_group: Optional[int],
+        n_experts: int,
+        prefix: str,
+        renormalize: bool,
+        topk: int,
+        topk_group: Optional[int],
+        weights: Weights,
+        gate_proj_name: str = "gate_proj",
+        up_proj_name: str = "up_proj",
+        down_proj_name: str = "down_proj",
+    ):
+        super().__init__()
+
+        assert (n_expert_group is None) == (
+            topk_group is None
+        ), "n_expert_group and topk_group must both be None or have some value"
+
+        self.n_expert_group = n_expert_group
+        self.topk = topk
+        self.topk_group = topk_group
+        self.renormalize = renormalize
+
+        (
+            self.gate_up_proj,
+            self.gate_up_proj_weight_scale,
+            self.gate_up_proj_input_scale,
+        ) = _load_expert_multi_weights_col(
+            prefix=prefix,
+            n_experts=n_experts,
+            gate_proj_name=gate_proj_name,
+            up_proj_name=up_proj_name,
+            weights=weights,
+        )
+
+        self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (
+            _load_expert_weights_row(
+                prefix=prefix,
+                n_experts=n_experts,
+                name=down_proj_name,
+                weights=weights,
+            )
+        )
+
+    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
+        return fused_moe(
+            x,
+            w1=self.gate_up_proj,
+            w2=self.down_proj,
+            gating_output=gating_output,
+            topk=self.topk,
+            renormalize=self.renormalize,
+            inplace=True,
+            use_grouped_topk=self.n_expert_group is not None,
+            num_expert_group=self.n_expert_group,
+            topk_group=self.topk_group,
+            use_fp8_w8a8=True,
+            w1_scale=self.gate_up_proj_weight_scale,
+            w2_scale=self.down_proj_weight_scale,
+            a1_scale=self.gate_up_proj_input_scale,
+            a2_scale=self.down_proj_input_scale,
+        )
+
+
+def _load_expert_weights(
+    get_weight_fn,
+    *,
+    prefix: str,
+    n_experts: int,
+    name: str,
+    weights: Weights,
+) -> torch.Tensor:
+    all_weight = None
+    all_weight_scales = None
+    max_input_scale = None
+
+    for i in range(n_experts):
+        weight = get_weight_fn(prefix, i, name, weights)
+
+        assert isinstance(weight, Fp8Weight)
+
+        if all_weight is None:
+            all_weight = torch.empty(
+                (n_experts,) + weight.weight.shape,
+                dtype=quant_dtype,
+                device=weight.weight.device,
+            )
+        if all_weight_scales is None:
+            all_weight_scales = torch.empty(
+                (n_experts,),
+                dtype=torch.float32,
+                device=weight.weight.device,
+            )
+
+        if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
+            all_weight[i], all_weight_scales[i], current_input_scale = (
+                normalize_e4m3fn_to_native_float8(
+                    weight.weight, weight.weight_scale, weight.input_scale
+                )
+            )
+            if current_input_scale is not None:
+                if max_input_scale is None or current_input_scale > max_input_scale:
+                    max_input_scale = current_input_scale
+        else:
+            all_weight[i], all_weight_scales[i] = fp8_quantize(
+                weight.weight, scalar=True
+            )
+
+    assert all_weight is not None
+
+    return all_weight, all_weight_scales, max_input_scale
+
+
+def _load_expert_multi_weights_col(
+    *,
+    prefix: str,
+    n_experts: int,
+    gate_proj_name: str,
+    up_proj_name: str,
+    weights: Weights,
+) -> torch.Tensor:
+    def get_weight_fn(prefix, i, name, weights):
+        return weights.get_multi_weights_col(
+            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
+        )
+
+    return _load_expert_weights(
+        get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights
+    )
+
+
+def _load_expert_weights_row(
+    *,
+    prefix: str,
+    n_experts: int,
+    name: str,
+    weights: Weights,
+) -> torch.Tensor:
+    def get_weight_fn(prefix, i, name, weights):
+        return weights.get_weights_row(f"{prefix}.{i}.{name}")
+
+    return _load_expert_weights(
+        get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights
+    )
diff --git a/server/text_generation_server/layers/moe/unquantized.py b/server/text_generation_server/layers/moe/unquantized.py
index 3c9bcabace2..32326653eb7 100644
--- a/server/text_generation_server/layers/moe/unquantized.py
+++ b/server/text_generation_server/layers/moe/unquantized.py
@@ -58,17 +58,7 @@ def __init__(
         )
 
     def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
-        if SYSTEM == "rocm":
-            return fused_moe(
-                x,
-                self.gate_up_proj,
-                self.down_proj,
-                gating_output,
-                self.topk,
-                renormalize=self.renormalize,
-                inplace=True,
-            )
-        elif SYSTEM == "ipex":
+        if SYSTEM == "ipex":
             return self.ipex_fused_moe(
                 hidden_states=x,
                 router_logits=gating_output,
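
Note (not part of the patch): a minimal, hypothetical usage sketch of the new FP8SparseMoELayer from this diff. The prefix, expert count, top-k, tensor shapes, and the `weights` object below are illustrative assumptions, not values taken from the patch; in TGI the class is normally constructed indirectly through SparseMoELayer when `weights.loader` is a HybridFP8UnquantLoader with `to_fp8` enabled.

    import torch
    from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer

    # `weights` is assumed to be a Weights instance backed by HybridFP8UnquantLoader,
    # pointing at an FP8 checkpoint whose experts use gate/up/down projection names.
    moe = FP8SparseMoELayer(
        n_expert_group=None,   # plain top-k routing, no grouped top-k
        n_experts=8,           # example value
        prefix="model.layers.0.block_sparse_moe.experts",  # example prefix
        renormalize=True,
        topk=2,
        topk_group=None,
        weights=weights,
    )

    # One small batch of hidden states plus the router logits for it.
    x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
    router_logits = torch.randn(16, 8, dtype=torch.float32, device="cuda")
    out = moe(x, gating_output=router_logits)  # same shape as x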