From 2f9d561cf732cce89653f43a58ed37fe3b5b7262 Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Wed, 5 Jun 2024 10:58:50 -0700
Subject: [PATCH] [Model] Correct Mixtral FP8 checkpoint loading (#5231)

---
 .../model_executor/layers/quantization/fp8.py |   7 +-
 vllm/model_executor/models/mixtral.py         | 108 ++++++++++++------
 2 files changed, 80 insertions(+), 35 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index b084b9cee498..bf3a59e3d709 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -300,14 +300,15 @@ def all_close_1d(x: torch.Tensor) -> bool:
 
 
 def per_tensor_quantize(tensor: torch.Tensor,
-                        inv_scale: float) -> torch.Tensor:
+                        inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
     finfo = torch.finfo(torch.float8_e4m3fn)
     qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
     return qweight.to(torch.float8_e4m3fn)
 
 
-def per_tensor_dequantize(tensor: torch.Tensor,
-                          inv_scale: float) -> torch.Tensor:
+def per_tensor_dequantize(
+        tensor: torch.Tensor, inv_scale: Union[float,
+                                               torch.Tensor]) -> torch.Tensor:
     fake_qweight = tensor.to(torch.float16)
     dq_weight = fake_qweight * inv_scale
     return dq_weight
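
For reference, the two helpers after this change perform a plain per-tensor FP8 round trip and now accept the scale either as a Python float or as a (0-dim) tensor. The following is a self-contained sketch of that behavior; it mirrors the diff but is illustrative rather than the vLLM module itself, and assumes a PyTorch build with float8_e4m3fn support:

    import torch

    def per_tensor_quantize(tensor: torch.Tensor,
                            inv_scale: torch.Tensor) -> torch.Tensor:
        # Divide by the scale, clamp to the representable e4m3 range, then cast.
        finfo = torch.finfo(torch.float8_e4m3fn)
        qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
        return qweight.to(torch.float8_e4m3fn)

    def per_tensor_dequantize(tensor: torch.Tensor,
                              inv_scale: torch.Tensor) -> torch.Tensor:
        # Upcast the fp8 values and multiply the scale back in.
        return tensor.to(torch.float16) * inv_scale

    weight = torch.randn(16, 32, dtype=torch.float16)
    scale = weight.abs().max().float() / torch.finfo(torch.float8_e4m3fn).max
    dequant = per_tensor_dequantize(per_tensor_quantize(weight, scale), scale)
    assert dequant.shape == weight.shape

Accepting a tensor scale matters for the Mixtral change below, where the scales passed in are slices of a parameter tensor rather than Python floats.
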
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index 2f4237339486..0f82549780ba 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -41,7 +41,9 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
+                                                         per_tensor_dequantize,
+                                                         per_tensor_quantize)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -98,16 +100,16 @@ def __init__(
 
         if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
-        self.w13_weight = nn.Parameter(
-            torch.empty(self.num_total_experts,
-                        2 * self.intermediate_size,
-                        self.hidden_size,
-                        dtype=params_dtype))
-        self.w2_weight = nn.Parameter(
-            torch.empty(self.num_total_experts,
-                        self.hidden_size,
-                        self.intermediate_size,
-                        dtype=params_dtype))
+        self.w13_weight = nn.Parameter(torch.empty(self.num_total_experts,
+                                                   2 * self.intermediate_size,
+                                                   self.hidden_size,
+                                                   dtype=params_dtype),
+                                       requires_grad=False)
+        self.w2_weight = nn.Parameter(torch.empty(self.num_total_experts,
+                                                  self.hidden_size,
+                                                  self.intermediate_size,
+                                                  dtype=params_dtype),
+                                      requires_grad=False)
 
         set_weight_attrs(self.w13_weight, {
             "weight_loader": self.weight_loader,
@@ -124,7 +126,10 @@ def __init__(
 
         if self.use_fp8:
             # WEIGHT_SCALE (for fp8)
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They will be combined to a single scale after weight loading.
             self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
+                                                     2,
                                                      dtype=torch.float32),
                                           requires_grad=False)
             self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
                                                     dtype=torch.float32),
                                          requires_grad=False)
@@ -148,11 +153,11 @@ def __init__(
                 raise ValueError(
                     "Found static activation scheme for checkpoint that "
                     "was not serialized fp8.")
-            self.a13_scale = nn.Parameter(torch.zeros(
+            self.a13_scale = nn.Parameter(torch.ones(
                 self.num_total_experts, dtype=torch.float32),
                                           requires_grad=False)
-            self.a2_scale = nn.Parameter(torch.zeros(
-                self.num_total_experts, dtype=torch.float32),
+            self.a2_scale = nn.Parameter(torch.ones(self.num_total_experts,
+                                                    dtype=torch.float32),
                                          requires_grad=False)
 
             set_weight_attrs(self.a13_scale, {
@@ -175,8 +180,22 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                        shard_size:2 * shard_size, :] = loaded_weight[shard, :]
         if weight_name.endswith("w2.weight"):
             param_data[expert_id, :, :] = loaded_weight[:, shard]
-        if "act_scale" in weight_name or "weight_scale" in weight_name:
+
+        # Loading scales
+        if "act_scale" in weight_name or "w2.weight_scale" in weight_name:
+            if param_data[expert_id] != 1 and (param_data[expert_id] -
+                                               loaded_weight).abs() > 1e-5:
+                raise ValueError(
+                    "act_scales of w1 and w3 of a layer "
+                    f"must be equal. But got {param_data[expert_id]} "
+                    f"vs. {loaded_weight}")
             param_data[expert_id] = loaded_weight
+        elif "weight_scale" in weight_name:
+            # We have to keep the weight scales of w1 and w3 because
+            # we need to re-quantize w1/w3 weights after weight loading.
+            assert "w1" in weight_name or "w3" in weight_name
+            shard_id = 0 if "w1" in weight_name else 1
+            param_data[expert_id][shard_id] = loaded_weight
 
     def process_weights_after_loading(self):
         # Fp8 is the only case where we need to process after loading.
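
The weight_loader change above is the setup for the fix: w1 and w3 weight scales are kept per shard (index 0 for w1, index 1 for w3) instead of being overwritten, so they can be merged into one per-expert scale later. A hypothetical, simplified sketch of that routing, with made-up names rather than the vLLM attributes:

    import torch

    num_experts = 8
    w13_scale = torch.ones(num_experts, 2)

    def load_w13_weight_scale(weight_name: str, expert_id: int,
                              loaded: torch.Tensor) -> None:
        # Keep w1 and w3 scales separately; they are merged after loading.
        assert "w1" in weight_name or "w3" in weight_name
        shard_id = 0 if "w1" in weight_name else 1
        w13_scale[expert_id, shard_id] = loaded

    load_w13_weight_scale("experts.3.w3.weight_scale", expert_id=3,
                          loaded=torch.tensor(0.05))
    assert w13_scale[3, 1] == 0.05
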
") + if (not all_close_1d(self.a13_scale) + or not all_close_1d(self.a2_scale)): + print_warning_once( + "Found act_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer. ") - self.a13_scale = nn.Parameter(self.a13_scale.max(), - requires_grad=False) - self.a2_scale = nn.Parameter(self.a2_scale.max(), - requires_grad=False) + self.a13_scale = nn.Parameter(self.a13_scale.max(), + requires_grad=False) + self.a2_scale = nn.Parameter(self.a2_scale.max(), + requires_grad=False) + + assert self.w13_scale is not None + shard_size = self.intermediate_size + max_w13_scales = self.w13_scale.max(dim=1).values + for expert_id in range(self.num_total_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + self.w13_weight[expert_id][start:start + + shard_size, :], + self.w13_scale[expert_id][shard_id]) + self.w13_weight[expert_id][ + start:start + shard_size, :] = per_tensor_quantize( + dq_weight, max_w13_scales[expert_id]) + start += shard_size + + self.w13_scale = nn.Parameter(max_w13_scales, requires_grad=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_size = hidden_states.shape