diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index c06af04d52769..c3ff706c37e58 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -4,7 +4,6 @@ import torch from vllm import _custom_ops as ops - from vllm.triton_utils import HAS_TRITON if HAS_TRITON: diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 7c021fb30238c..a896127430883 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -6,13 +6,6 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger - -from vllm.triton_utils import HAS_TRITON - -if HAS_TRITON: - from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, - fused_moe) - from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization.base_config import ( @@ -23,11 +16,16 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - all_close_1d, apply_fp8_linear, create_per_tensor_scale_param, - cutlass_fp8_supported, per_tensor_dequantize, requantize_with_max_scale) + apply_fp8_linear, create_per_tensor_scale_param, cutlass_fp8_supported, + requantize_with_max_scale) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.utils import print_warning_once +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.model_executor.layers.fused_moe import FusedMoE + from vllm.model_executor.layers.quantization.fp8_fused_moe import ( + Fp8MoEMethod) ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -232,189 +230,6 @@ def apply(self, cutlass_fp8_supported=self.cutlass_fp8_supported, use_per_token_if_dynamic=False) -if HAS_TRITON: - - class Fp8MoEMethod(FusedMoEMethodBase): - """MoE method for FP8. - Supports loading FP8 checkpoints with static weight scale and - dynamic/static activation scale. - - Also supports loading quantized FP16/BF16 model checkpoints with dynamic - activation scaling. The weight scaling factor will be initialized after - the model weights are loaded. - - Args: - quant_config: The quantization config. - """ - - def __init__(self, quant_config: Fp8Config): - self.quant_config = quant_config - - def create_weights(self, layer: Module, num_experts: int, hidden_size: int, - intermediate_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): - - if self.quant_config.is_checkpoint_fp8_serialized: - params_dtype = torch.float8_e4m3fn - - # WEIGHTS - w13_weight = torch.nn.Parameter(torch.empty(num_experts, - 2 * intermediate_size, - hidden_size, - dtype=params_dtype), - requires_grad=False) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - w2_weight = torch.nn.Parameter(torch.empty(num_experts, - hidden_size, - intermediate_size, - dtype=params_dtype), - requires_grad=False) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. 
- w13_scale = torch.nn.Parameter(torch.ones(num_experts, - 2, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w13_scale", w13_scale) - - w2_scale = torch.nn.Parameter(torch.ones(num_experts, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w2_scale", w2_scale) - - # If loading fp8 checkpoint, pass the weight loaders. - # If loading an fp16 checkpoint, do not (we will quantize in - # process_weights_after_loading() - if self.quant_config.is_checkpoint_fp8_serialized: - set_weight_attrs(w13_scale, extra_weight_attrs) - set_weight_attrs(w2_scale, extra_weight_attrs) - - # INPUT_SCALES - if self.quant_config.activation_scheme == "static": - if not self.quant_config.is_checkpoint_fp8_serialized: - raise ValueError( - "Found static activation scheme for checkpoint that " - "was not serialized fp8.") - - a13_scale = torch.nn.Parameter(torch.ones(num_experts, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("a13_scale", a13_scale) - set_weight_attrs(a13_scale, extra_weight_attrs) - - a2_scale = torch.nn.Parameter(torch.ones(num_experts, - dtype=torch.float32), - requires_grad=False) - layer.register_parameter("a2_scale", a2_scale) - set_weight_attrs(a2_scale, extra_weight_attrs) - else: - layer.a13_scale = None - layer.a2_scale = None - - def process_weights_after_loading(self, layer: Module) -> None: - - # If checkpoint is fp16, quantize in place. - if not self.quant_config.is_checkpoint_fp8_serialized: - w13_weight = torch.empty_like(layer.w13_weight.data, - dtype=torch.float8_e4m3fn) - w2_weight = torch.empty_like(layer.w2_weight.data, - dtype=torch.float8_e4m3fn) - - # Re-initialize w13_scale because we directly quantize - # merged w13 weights and generate a single scaling factor. - layer.w13_scale = torch.nn.Parameter(torch.ones( - layer.num_experts, - dtype=torch.float32, - device=w13_weight.device), - requires_grad=False) - for expert in range(layer.num_experts): - w13_weight[expert, :, :], layer.w13_scale[ - expert] = ops.scaled_fp8_quant( - layer.w13_weight.data[expert, :, :]) - w2_weight[expert, :, :], layer.w2_scale[ - expert] = ops.scaled_fp8_quant( - layer.w2_weight.data[expert, :, :]) - layer.w13_weight = torch.nn.Parameter(w13_weight, - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(w2_weight, - requires_grad=False) - return - - # If checkpoint is fp8, we need to handle that the - # MoE kernels require single activation scale and single weight - # scale for w13 per expert. - else: - # Fp8 moe kernels require a single activation scale. - # We take the max of all the scales in case they differ. - if self.quant_config.activation_scheme == "static": - if layer.a13_scale is None or layer.a2_scale is None: - raise ValueError( - "QuantConfig has static quantization, but found " - "activation scales are None.") - if (not all_close_1d(layer.a13_scale) - or not all_close_1d(layer.a2_scale)): - print_warning_once( - "Found input_scales that are not equal for " - "fp8 MoE layer. Using the maximum across experts " - "for each layer. ") - layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(), - requires_grad=False) - layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(), - requires_grad=False) - - # Fp8 moe kernel needs single weight scale for w13 per expert. - # We take the max then dequant and requant each expert. 
- assert layer.w13_scale is not None - shard_size = layer.intermediate_size_per_partition - max_w13_scales = layer.w13_scale.max(dim=1).values - for expert_id in range(layer.num_experts): - start = 0 - for shard_id in range(2): - dq_weight = per_tensor_dequantize( - layer.w13_weight[expert_id][start:start + - shard_size, :], - layer.w13_scale[expert_id][shard_id]) - layer.w13_weight[expert_id][ - start:start + shard_size, :], _ = ops.scaled_fp8_quant( - dq_weight, max_w13_scales[expert_id]) - start += shard_size - - layer.w13_scale = torch.nn.Parameter(max_w13_scales, - requires_grad=False) - return - - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None) -> torch.Tensor: - - return fused_moe(x, - layer.w13_weight, - layer.w2_weight, - router_logits, - top_k, - renormalize=renormalize, - inplace=True, - use_fp8=True, - w1_scale=layer.w13_scale, - w2_scale=layer.w2_scale, - a1_scale=layer.a13_scale, - a2_scale=layer.a2_scale, - use_grouped_topk=use_grouped_topk, - num_expert_group=num_expert_group, - topk_group=topk_group) - class Fp8KVCacheMethod(BaseKVCacheMethod): """ diff --git a/vllm/model_executor/layers/quantization/fp8_fused_moe.py b/vllm/model_executor/layers/quantization/fp8_fused_moe.py new file mode 100644 index 0000000000000..fcc5d884d6fb2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/fp8_fused_moe.py @@ -0,0 +1,196 @@ +from typing import Optional + +import torch +from torch.nn import Module + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase, + fused_moe) +from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + all_close_1d, per_tensor_dequantize) +from vllm.utils import print_warning_once + +class Fp8MoEMethod(FusedMoEMethodBase): + """MoE method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + + Args: + quant_config: The quantization config. + """ + + def __init__(self, + quant_config: "Fp8Config" # type: ignore + ): + self.quant_config = quant_config + + def create_weights(self, layer: Module, num_experts: int, + hidden_size: int, intermediate_size: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + if self.quant_config.is_checkpoint_fp8_serialized: + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty(num_experts, + 2 * intermediate_size, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty(num_experts, + hidden_size, + intermediate_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. 
+ w13_scale = torch.nn.Parameter(torch.ones(num_experts, + 2, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_scale", w13_scale) + + w2_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_scale", w2_scale) + + # If loading fp8 checkpoint, pass the weight loaders. + # If loading an fp16 checkpoint, do not (we will quantize in + # process_weights_after_loading() + if self.quant_config.is_checkpoint_fp8_serialized: + set_weight_attrs(w13_scale, extra_weight_attrs) + set_weight_attrs(w2_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.quant_config.activation_scheme == "static": + if not self.quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "Found static activation scheme for checkpoint that " + "was not serialized fp8.") + + a13_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("a13_scale", a13_scale) + set_weight_attrs(a13_scale, extra_weight_attrs) + + a2_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("a2_scale", a2_scale) + set_weight_attrs(a2_scale, extra_weight_attrs) + else: + layer.a13_scale = None + layer.a2_scale = None + + def process_weights_after_loading(self, layer: Module) -> None: + + # If checkpoint is fp16, quantize in place. + if not self.quant_config.is_checkpoint_fp8_serialized: + w13_weight = torch.empty_like(layer.w13_weight.data, + dtype=torch.float8_e4m3fn) + w2_weight = torch.empty_like(layer.w2_weight.data, + dtype=torch.float8_e4m3fn) + + # Re-initialize w13_scale because we directly quantize + # merged w13 weights and generate a single scaling factor. + layer.w13_scale = torch.nn.Parameter(torch.ones( + layer.num_experts, + dtype=torch.float32, + device=w13_weight.device), + requires_grad=False) + for expert in range(layer.num_experts): + w13_weight[expert, :, :], layer.w13_scale[ + expert] = ops.scaled_fp8_quant( + layer.w13_weight.data[expert, :, :]) + w2_weight[expert, :, :], layer.w2_scale[ + expert] = ops.scaled_fp8_quant( + layer.w2_weight.data[expert, :, :]) + layer.w13_weight = torch.nn.Parameter(w13_weight, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, + requires_grad=False) + return + + # If checkpoint is fp8, we need to handle that the + # MoE kernels require single activation scale and single weight + # scale for w13 per expert. + else: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.quant_config.activation_scheme == "static": + if layer.a13_scale is None or layer.a2_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None.") + if (not all_close_1d(layer.a13_scale) + or not all_close_1d(layer.a2_scale)): + print_warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer. ") + layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(), + requires_grad=False) + layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(), + requires_grad=False) + + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max then dequant and requant each expert. 
+ assert layer.w13_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_scale.max(dim=1).values + for expert_id in range(layer.num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start:start + + shard_size, :], + layer.w13_scale[expert_id][shard_id]) + layer.w13_weight[expert_id][ + start:start + + shard_size, :], _ = ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id]) + start += shard_size + + layer.w13_scale = torch.nn.Parameter(max_w13_scales, + requires_grad=False) + return + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None) -> torch.Tensor: + + return fused_moe(x, + layer.w13_weight, + layer.w2_weight, + router_logits, + top_k, + renormalize=renormalize, + inplace=True, + use_fp8=True, + w1_scale=layer.w13_scale, + w2_scale=layer.w2_scale, + a1_scale=layer.a13_scale, + a2_scale=layer.a2_scale, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 4968272d37e0c..0f193559e867f 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -7,7 +7,8 @@ from vllm.triton_utils import HAS_TRITON if HAS_TRITON: - from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits + from vllm.model_executor.layers.ops.sample import ( + get_num_triton_sampler_splits) from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SequenceData, SequenceGroupMetadata diff --git a/vllm/triton_utils/__init__.py b/vllm/triton_utils/__init__.py index 6a8c71878eaa4..4572b45ceaeca 100644 --- a/vllm/triton_utils/__init__.py +++ b/vllm/triton_utils/__init__.py @@ -1,5 +1,11 @@ -from vllm.triton_utils.custom_cache_manager import ( - maybe_set_triton_cache_manager) from vllm.triton_utils.importing import HAS_TRITON +if HAS_TRITON: + from vllm.triton_utils.custom_cache_manager import ( + maybe_set_triton_cache_manager) + __all__ = ["HAS_TRITON", "maybe_set_triton_cache_manager"] + +if not HAS_TRITON: + # need to do this afterwards due to ruff complaining + __all__.pop() diff --git a/vllm/triton_utils/custom_cache_manager.py b/vllm/triton_utils/custom_cache_manager.py index eb9a4a1d08153..17039d7ba24c7 100644 --- a/vllm/triton_utils/custom_cache_manager.py +++ b/vllm/triton_utils/custom_cache_manager.py @@ -1,9 +1,9 @@ import os -from vllm.logger import init_logger - +from triton.runtime.cache import (FileCacheManager, default_cache_dir, + default_dump_dir, default_override_dir) -from .importing import HAS_TRITON +from vllm.logger import init_logger logger = init_logger(__name__) @@ -18,41 +18,36 @@ def maybe_set_triton_cache_manager() -> None: os.environ["TRITON_CACHE_MANAGER"] = manager -if HAS_TRITON: - - from triton.runtime.cache import (FileCacheManager, default_cache_dir, - default_dump_dir, default_override_dir) - - class CustomCacheManager(FileCacheManager): - """Re-implements Triton's cache manager, ensuring that a - unique cache directory is created for each process. This is - needed to avoid collisions when running with tp>1 and - using multi-processing as the distributed backend. 
- - Note this issue was fixed by triton-lang/triton/pull/4295, - but the fix is not yet included in triton==v3.0.0. However, - it should be included in the subsequent version. - """ - - def __init__(self, key, override=False, dump=False): - self.key = key - self.lock_path = None - if dump: - self.cache_dir = default_dump_dir() +class CustomCacheManager(FileCacheManager): + """Re-implements Triton's cache manager, ensuring that a + unique cache directory is created for each process. This is + needed to avoid collisions when running with tp>1 and + using multi-processing as the distributed backend. + + Note this issue was fixed by triton-lang/triton/pull/4295, + but the fix is not yet included in triton==v3.0.0. However, + it should be included in the subsequent version. + """ + + def __init__(self, key, override=False, dump=False): + self.key = key + self.lock_path = None + if dump: + self.cache_dir = default_dump_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + elif override: + self.cache_dir = default_override_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + else: + # create cache directory if it doesn't exist + self.cache_dir = os.getenv("TRITON_CACHE_DIR", + "").strip() or default_cache_dir() + if self.cache_dir: + self.cache_dir = f"{self.cache_dir}_{os.getpid()}" self.cache_dir = os.path.join(self.cache_dir, self.key) self.lock_path = os.path.join(self.cache_dir, "lock") os.makedirs(self.cache_dir, exist_ok=True) - elif override: - self.cache_dir = default_override_dir() - self.cache_dir = os.path.join(self.cache_dir, self.key) else: - # create cache directory if it doesn't exist - self.cache_dir = os.getenv("TRITON_CACHE_DIR", - "").strip() or default_cache_dir() - if self.cache_dir: - self.cache_dir = f"{self.cache_dir}_{os.getpid()}" - self.cache_dir = os.path.join(self.cache_dir, self.key) - self.lock_path = os.path.join(self.cache_dir, "lock") - os.makedirs(self.cache_dir, exist_ok=True) - else: - raise RuntimeError("Could not create or locate cache dir") + raise RuntimeError("Could not create or locate cache dir") diff --git a/vllm/triton_utils/importing.py b/vllm/triton_utils/importing.py index d6aa0d21f0981..3455036586a93 100644 --- a/vllm/triton_utils/importing.py +++ b/vllm/triton_utils/importing.py @@ -1,12 +1,11 @@ +from importlib.util import find_spec from vllm.logger import init_logger logger = init_logger(__name__) -try: - import triton - HAS_TRITON = True -except ImportError: +HAS_TRITON = find_spec("triton") is not None + +if not HAS_TRITON: logger.info("Triton not installed; certain GPU-related functions" " will be not be available.") - HAS_TRITON = False
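
Note on the gating pattern applied throughout this patch: the try/except around "import triton" is replaced by an importlib.util.find_spec() probe, and every Triton-dependent piece of code (Fp8MoEMethod, CustomCacheManager, the fused_moe import) is either guarded by "if HAS_TRITON:" or moved into a module that is only imported when Triton is present. The following is a minimal, self-contained sketch of that pattern; it mirrors the shape of vllm/triton_utils/__init__.py, but the helper name load_triton_only_symbols is illustrative and not part of vLLM.

    from importlib.util import find_spec

    # Probe for the optional dependency without importing it, so this module
    # can always be imported, even on a Triton-free install.
    HAS_TRITON = find_spec("triton") is not None

    __all__ = ["HAS_TRITON"]

    if HAS_TRITON:
        # Only now is it safe to import code whose module level requires triton.
        import triton  # noqa: F401
        __all__.append("triton")


    def load_triton_only_symbols():
        """Illustrative helper: resolve Triton-backed code lazily."""
        if not HAS_TRITON:
            raise ImportError(
                "triton is not installed; Triton-backed kernels are unavailable")
        # In the real patch this is where a Triton-dependent class such as
        # Fp8MoEMethod lives in its own module, keeping the non-Triton import
        # path free of any "import triton" at module load time.
        return triton

Called on a machine without Triton, importing the sketch module succeeds and only load_triton_only_symbols() raises; with Triton installed, the guarded import runs and the symbol is exported, which is the same behavior the patch gives HAS_TRITON consumers.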