From 50f3ecf07d4566ee8be8a5c4378a940c7fbd9dc9 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 4 Oct 2024 20:51:20 -0400 Subject: [PATCH 1/5] Try to handle older versions of pytorch --- vllm/_custom_ops.py | 45 +++++++++++++++++++++++++-------------------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 24e008dc3802..d5137fcfc6d4 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -25,6 +25,11 @@ import vllm._moe_C # noqa: F401 supports_moe_ops = True +import torch.library +try: + import torch.library.register_fake +except ImportError: + from torch.library import impl_abstract as register_fake def hint_on_error(fn): @@ -266,7 +271,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, if hasattr(torch.ops._C, "gptq_gemm"): - @torch.library.register_fake("_C::gptq_gemm") + @register_fake("_C::gptq_gemm") def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor, @@ -301,7 +306,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): - @torch.library.register_fake("_C::gptq_marlin_24_gemm") + @register_fake("_C::gptq_marlin_24_gemm") def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_meta: torch.Tensor, b_scales: torch.Tensor, workspace: torch.Tensor, @@ -309,7 +314,7 @@ def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, size_n: int, size_k: int) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - @torch.library.register_fake("_C::gptq_marlin_gemm") + @register_fake("_C::gptq_marlin_gemm") def _gptq_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, @@ -326,12 +331,12 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor, use_fp32_reduce: bool = False) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - @torch.library.register_fake("_C::ggml_dequantize") + @register_fake("_C::ggml_dequantize") def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int, n: int) -> torch.Tensor: return torch.empty((m, n), dtype=torch.float16, device=W.device) - @torch.library.register_fake("_C::ggml_mul_mat_vec_a8") + @register_fake("_C::ggml_mul_mat_vec_a8") def _ggml_mul_mat_vec_a8_fake( W: torch.Tensor, X: torch.Tensor, @@ -340,7 +345,7 @@ def _ggml_mul_mat_vec_a8_fake( ) -> torch.Tensor: return torch.empty((1, row), dtype=torch.float16, device=W.device) - @torch.library.register_fake("_C::ggml_mul_mat_a8") + @register_fake("_C::ggml_mul_mat_a8") def _ggml_mul_mat_a8_fake( W: torch.Tensor, X: torch.Tensor, @@ -350,7 +355,7 @@ def _ggml_mul_mat_a8_fake( batch = X.size(0) return torch.empty((batch, row), dtype=torch.float16, device=W.device) - @torch.library.register_fake("_C::marlin_qqq_gemm") + @register_fake("_C::marlin_qqq_gemm") def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, s_tok: torch.Tensor, s_ch: torch.Tensor, s_group: torch.Tensor, workspace: torch.Tensor, @@ -360,7 +365,7 @@ def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, dtype=torch.float16, device=a.device) - @torch.library.register_fake("_C::marlin_gemm") + @register_fake("_C::marlin_gemm") def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, size_n: int, @@ -369,7 +374,7 @@ def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, 
dtype=torch.float16, device=a.device) - @torch.library.register_fake("_C::awq_dequantize") + @register_fake("_C::awq_dequantize") def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, split_k_iters: int, thx: int, thy: int) -> torch.Tensor: @@ -380,7 +385,7 @@ def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor, dtype=scales.dtype, device=scales.device) - @torch.library.register_fake("_C::awq_gemm") + @register_fake("_C::awq_gemm") def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: @@ -389,7 +394,7 @@ def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor, dtype=input.dtype, device=input.device).sum(0) - @torch.library.register_fake("_C::aqlm_gemm") + @register_fake("_C::aqlm_gemm") def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor, codebooks: torch.Tensor, scales: torch.Tensor, codebook_partition_sizes: List[int], @@ -405,7 +410,7 @@ def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor, output_sizes.append(-1) return flat_output.reshape(tuple(output_sizes)) - @torch.library.register_fake("_C::aqlm_dequant") + @register_fake("_C::aqlm_dequant") def _aqlm_dequant_fake( codes: torch.Tensor, codebooks: torch.Tensor, codebook_partition_sizes: List[int]) -> torch.Tensor: @@ -415,14 +420,14 @@ def _aqlm_dequant_fake( dtype=codebooks.dtype, device=codebooks.device) - @torch.library.register_fake("_C::fp8_marlin_gemm") + @register_fake("_C::fp8_marlin_gemm") def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, workspace: torch.Tensor, num_bits: int, size_m: int, size_n: int, size_k: int) -> torch.Tensor: return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - @torch.library.register_fake("_C::machete_gemm") + @register_fake("_C::machete_gemm") def machete_gemm_fake( a: torch.Tensor, # Should be the tensor returned by machete_prepack_B @@ -440,13 +445,13 @@ def machete_gemm_fake( n = b_q.size(1) return torch.empty((m, n), device=a.device, dtype=a.dtype) - @torch.library.register_fake("_C::machete_prepack_B") + @register_fake("_C::machete_prepack_B") def machete_prepack_B_fake(b_q_weight: torch.Tensor, b_type: ScalarType) -> torch.Tensor: return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format) - @torch.library.register_fake("_C::causal_conv1d_fwd") + @register_fake("_C::causal_conv1d_fwd") def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor, bias_: Optional[torch.Tensor], conv_states: Optional[torch.Tensor], @@ -456,7 +461,7 @@ def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor, silu_activation: bool) -> torch.Tensor: return torch.empty_like(x) - @torch.library.register_fake("_C::causal_conv1d_update") + @register_fake("_C::causal_conv1d_update") def causal_conv1d_update_fake( x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, bias_: Optional[torch.Tensor], silu_activation: bool, @@ -464,7 +469,7 @@ def causal_conv1d_update_fake( conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor: return torch.empty_like(x) - @torch.library.register_fake("_C::selective_scan_fwd") + @register_fake("_C::selective_scan_fwd") def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, D_: Optional[torch.Tensor], @@ -639,7 +644,7 @@ def machete_prepack_B(b_q_weight: torch.Tensor, if hasattr(torch.ops._C, "permute_cols"): - @torch.library.register_fake("_C::permute_cols") + 
@register_fake("_C::permute_cols") def _permute_cols_fake(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: return torch.empty_like(a) @@ -837,7 +842,7 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): - @torch.library.register_fake("_moe_C::marlin_gemm_moe") + @register_fake("_moe_C::marlin_gemm_moe") def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor, sorted_ids: torch.Tensor, topk_weights: torch.Tensor, From 690fdfb6f07de9a1870b9c82c6443fb937708157 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 4 Oct 2024 20:57:27 -0400 Subject: [PATCH 2/5] format --- vllm/_custom_ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index d5137fcfc6d4..a6de3b13019f 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -3,6 +3,7 @@ from typing import List, Optional, Tuple, Union import torch +import torch.library import vllm.envs as envs from vllm._core_ext import ScalarType @@ -25,12 +26,12 @@ import vllm._moe_C # noqa: F401 supports_moe_ops = True -import torch.library try: import torch.library.register_fake except ImportError: from torch.library import impl_abstract as register_fake + def hint_on_error(fn): @functools.wraps(fn) From 745adb3528383d9bd760ff3afe6c9726b4efe4a5 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 4 Oct 2024 21:17:03 -0400 Subject: [PATCH 3/5] fix --- vllm/_custom_ops.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index a6de3b13019f..3a23692285ef 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,6 +1,6 @@ import contextlib import functools -from typing import List, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch import torch.library @@ -26,10 +26,15 @@ import vllm._moe_C # noqa: F401 supports_moe_ops = True -try: - import torch.library.register_fake -except ImportError: - from torch.library import impl_abstract as register_fake +if TYPE_CHECKING: + + def register_fake(fn): + return lambda name: fn +else: + try: + from torch.library import register_fake + except ImportError: + from torch.library import impl_abstract as register_fake def hint_on_error(fn): From c36564f130900bf046529f3556347f920f4bea35 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 8 Oct 2024 15:26:34 +0000 Subject: [PATCH 4/5] skip some tests if not supported --- tests/kernels/test_awq.py | 3 +++ tests/kernels/test_awq_marlin.py | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/tests/kernels/test_awq.py b/tests/kernels/test_awq.py index e421aca48af2..bff9c34da7b7 100644 --- a/tests/kernels/test_awq.py +++ b/tests/kernels/test_awq.py @@ -1,11 +1,13 @@ import os +import pytest import torch from tests.kernels.utils import opcheck from vllm import _custom_ops as ops # noqa: F401 +@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_dequantize")) def test_awq_dequantize_opcheck(): os.environ["VLLM_USE_TRITON_AWQ"] = "0" qweight = torch.randint(-2000000000, @@ -21,6 +23,7 @@ def test_awq_dequantize_opcheck(): (qweight, scales, zeros, split_k_iters, thx, thy)) +@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_gemm")) def test_awq_gemm_opcheck(): os.environ["VLLM_USE_TRITON_AWQ"] = "0" input = torch.rand((2, 8192), device='cuda', dtype=torch.float16) diff --git a/tests/kernels/test_awq_marlin.py b/tests/kernels/test_awq_marlin.py index 0738ea9b97ed..1c5221fc39f8 100644 --- 
a/tests/kernels/test_awq_marlin.py +++ b/tests/kernels/test_awq_marlin.py @@ -7,6 +7,7 @@ from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe, torch_moe_single) +from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( fused_marlin_moe, single_marlin_moe) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk @@ -21,6 +22,9 @@ @pytest.mark.parametrize("e", [8, 64]) @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) +@pytest.mark.skipif( + not (ops.supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe")) +) def test_fused_marlin_moe_awq( m: int, n: int, From 30292b948afb3f599f0e1602872d7f158c686c1b Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 8 Oct 2024 18:54:35 +0000 Subject: [PATCH 5/5] add reasons to skipifs --- tests/kernels/test_awq.py | 6 ++++-- tests/kernels/test_awq_marlin.py | 6 +++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/kernels/test_awq.py b/tests/kernels/test_awq.py index bff9c34da7b7..aa7a430850f9 100644 --- a/tests/kernels/test_awq.py +++ b/tests/kernels/test_awq.py @@ -7,7 +7,8 @@ from vllm import _custom_ops as ops # noqa: F401 -@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_dequantize")) +@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_dequantize"), + reason="AWQ is not supported on this GPU type.") def test_awq_dequantize_opcheck(): os.environ["VLLM_USE_TRITON_AWQ"] = "0" qweight = torch.randint(-2000000000, @@ -23,7 +24,8 @@ def test_awq_dequantize_opcheck(): (qweight, scales, zeros, split_k_iters, thx, thy)) -@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_gemm")) +@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_gemm"), + reason="AWQ is not supported on this GPU type.") def test_awq_gemm_opcheck(): os.environ["VLLM_USE_TRITON_AWQ"] = "0" input = torch.rand((2, 8192), device='cuda', dtype=torch.float16) diff --git a/tests/kernels/test_awq_marlin.py b/tests/kernels/test_awq_marlin.py index 1c5221fc39f8..0f0a2b24563f 100644 --- a/tests/kernels/test_awq_marlin.py +++ b/tests/kernels/test_awq_marlin.py @@ -22,9 +22,9 @@ @pytest.mark.parametrize("e", [8, 64]) @pytest.mark.parametrize("topk", [2, 6]) @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) -@pytest.mark.skipif( - not (ops.supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe")) -) +@pytest.mark.skipif(not (ops.supports_moe_ops + and hasattr(torch.ops._moe_C, "marlin_gemm_moe")), + reason="Marlin is not supported on this GPU type.") def test_fused_marlin_moe_awq( m: int, n: int,
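
Taken together, the five patches boil down to two patterns: import register_fake from torch.library when it exists (newer PyTorch releases) and fall back to impl_abstract on older ones, and skip kernel tests whose custom op is absent from the current build. The sketch below is a minimal, self-contained illustration of both, assuming a vLLM build that registers its custom ops under torch.ops._C; the output shape used in the fake gptq_gemm is illustrative rather than copied from the final file.

    from typing import TYPE_CHECKING

    import pytest
    import torch
    import torch.library

    if TYPE_CHECKING:
        # Static stand-in so type checkers accept the decorator on either path.
        def register_fake(fn):
            return lambda name: fn
    else:
        try:
            from torch.library import register_fake  # newer PyTorch
        except ImportError:
            from torch.library import impl_abstract as register_fake  # older PyTorch

    # Register a fake (meta) implementation only when the compiled op exists,
    # mirroring the hasattr guards in vllm/_custom_ops.py.
    if hasattr(torch.ops._C, "gptq_gemm"):

        @register_fake("_C::gptq_gemm")
        def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
                            b_gptq_qzeros: torch.Tensor,
                            b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
                            use_exllama: bool, bit: int) -> torch.Tensor:
            # Illustrative output shape; the real fake mirrors the kernel's output.
            return torch.empty((a.size(0), b_q_weight.size(-1)),
                               dtype=a.dtype,
                               device=a.device)

    # Skip kernel tests on builds that lack the op, with a readable reason
    # (the form PATCH 5/5 settles on).
    @pytest.mark.skipif(not hasattr(torch.ops._C, "awq_gemm"),
                        reason="AWQ is not supported on this GPU type.")
    def test_awq_gemm_opcheck():
        ...

The TYPE_CHECKING branch matches PATCH 3/5: it keeps static type checkers satisfied regardless of which symbol the installed PyTorch exposes, while the runtime branch picks whichever name is actually importable.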