Skip to content

Commit

Permalink
[Bugfix] Try to handle older versions of pytorch (#9086)
Browse files Browse the repository at this point in the history
  • Loading branch information
bnellnm authored Oct 8, 2024
1 parent de24046 commit bd37b9f
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 21 deletions.
5 changes: 5 additions & 0 deletions tests/kernels/test_awq.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import os

import pytest
import torch

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401


@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_dequantize"),
reason="AWQ is not supported on this GPU type.")
def test_awq_dequantize_opcheck():
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
qweight = torch.randint(-2000000000,
Expand All @@ -21,6 +24,8 @@ def test_awq_dequantize_opcheck():
(qweight, scales, zeros, split_k_iters, thx, thy))


@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_gemm"),
reason="AWQ is not supported on this GPU type.")
def test_awq_gemm_opcheck():
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
input = torch.rand((2, 8192), device='cuda', dtype=torch.float16)
Expand Down
4 changes: 4 additions & 0 deletions tests/kernels/test_awq_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe,
torch_moe_single)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
Expand All @@ -21,6 +22,9 @@
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.skipif(not (ops.supports_moe_ops
and hasattr(torch.ops._moe_C, "marlin_gemm_moe")),
reason="Marlin is not supported on this GPU type.")
def test_fused_marlin_moe_awq(
m: int,
n: int,
Expand Down
53 changes: 32 additions & 21 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import contextlib
import functools
from typing import List, Optional, Tuple, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch
import torch.library

import vllm.envs as envs
from vllm._core_ext import ScalarType
Expand All @@ -25,6 +26,16 @@
import vllm._moe_C # noqa: F401
supports_moe_ops = True

if TYPE_CHECKING:

def register_fake(fn):
return lambda name: fn
else:
try:
from torch.library import register_fake
except ImportError:
from torch.library import impl_abstract as register_fake


def hint_on_error(fn):

Expand Down Expand Up @@ -266,7 +277,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,

if hasattr(torch.ops._C, "gptq_gemm"):

@torch.library.register_fake("_C::gptq_gemm")
@register_fake("_C::gptq_gemm")
def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor,
b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
Expand Down Expand Up @@ -301,15 +312,15 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,

if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):

@torch.library.register_fake("_C::gptq_marlin_24_gemm")
@register_fake("_C::gptq_marlin_24_gemm")
def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_meta: torch.Tensor, b_scales: torch.Tensor,
workspace: torch.Tensor,
b_q_type: ScalarType, size_m: int,
size_n: int, size_k: int) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

@torch.library.register_fake("_C::gptq_marlin_gemm")
@register_fake("_C::gptq_marlin_gemm")
def _gptq_marlin_gemm_fake(a: torch.Tensor,
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
Expand All @@ -326,12 +337,12 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor,
use_fp32_reduce: bool = False) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

@torch.library.register_fake("_C::ggml_dequantize")
@register_fake("_C::ggml_dequantize")
def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int,
n: int) -> torch.Tensor:
return torch.empty((m, n), dtype=torch.float16, device=W.device)

@torch.library.register_fake("_C::ggml_mul_mat_vec_a8")
@register_fake("_C::ggml_mul_mat_vec_a8")
def _ggml_mul_mat_vec_a8_fake(
W: torch.Tensor,
X: torch.Tensor,
Expand All @@ -340,7 +351,7 @@ def _ggml_mul_mat_vec_a8_fake(
) -> torch.Tensor:
return torch.empty((1, row), dtype=torch.float16, device=W.device)

@torch.library.register_fake("_C::ggml_mul_mat_a8")
@register_fake("_C::ggml_mul_mat_a8")
def _ggml_mul_mat_a8_fake(
W: torch.Tensor,
X: torch.Tensor,
Expand All @@ -350,7 +361,7 @@ def _ggml_mul_mat_a8_fake(
batch = X.size(0)
return torch.empty((batch, row), dtype=torch.float16, device=W.device)

@torch.library.register_fake("_C::marlin_qqq_gemm")
@register_fake("_C::marlin_qqq_gemm")
def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
s_tok: torch.Tensor, s_ch: torch.Tensor,
s_group: torch.Tensor, workspace: torch.Tensor,
Expand All @@ -360,7 +371,7 @@ def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
dtype=torch.float16,
device=a.device)

@torch.library.register_fake("_C::marlin_gemm")
@register_fake("_C::marlin_gemm")
def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
size_m: int, size_n: int,
Expand All @@ -369,7 +380,7 @@ def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
dtype=torch.float16,
device=a.device)

@torch.library.register_fake("_C::awq_dequantize")
@register_fake("_C::awq_dequantize")
def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor:
Expand All @@ -380,7 +391,7 @@ def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
dtype=scales.dtype,
device=scales.device)

@torch.library.register_fake("_C::awq_gemm")
@register_fake("_C::awq_gemm")
def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
qzeros: torch.Tensor, scales: torch.Tensor,
split_k_iters: int) -> torch.Tensor:
Expand All @@ -389,7 +400,7 @@ def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
dtype=input.dtype,
device=input.device).sum(0)

@torch.library.register_fake("_C::aqlm_gemm")
@register_fake("_C::aqlm_gemm")
def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
codebooks: torch.Tensor, scales: torch.Tensor,
codebook_partition_sizes: List[int],
Expand All @@ -405,7 +416,7 @@ def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
output_sizes.append(-1)
return flat_output.reshape(tuple(output_sizes))

@torch.library.register_fake("_C::aqlm_dequant")
@register_fake("_C::aqlm_dequant")
def _aqlm_dequant_fake(
codes: torch.Tensor, codebooks: torch.Tensor,
codebook_partition_sizes: List[int]) -> torch.Tensor:
Expand All @@ -415,14 +426,14 @@ def _aqlm_dequant_fake(
dtype=codebooks.dtype,
device=codebooks.device)

@torch.library.register_fake("_C::fp8_marlin_gemm")
@register_fake("_C::fp8_marlin_gemm")
def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
num_bits: int, size_m: int, size_n: int,
size_k: int) -> torch.Tensor:
return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)

@torch.library.register_fake("_C::machete_gemm")
@register_fake("_C::machete_gemm")
def machete_gemm_fake(
a: torch.Tensor,
# Should be the tensor returned by machete_prepack_B
Expand All @@ -440,13 +451,13 @@ def machete_gemm_fake(
n = b_q.size(1)
return torch.empty((m, n), device=a.device, dtype=a.dtype)

@torch.library.register_fake("_C::machete_prepack_B")
@register_fake("_C::machete_prepack_B")
def machete_prepack_B_fake(b_q_weight: torch.Tensor,
b_type: ScalarType) -> torch.Tensor:
return torch.empty_like(b_q_weight,
memory_format=torch.contiguous_format)

@torch.library.register_fake("_C::causal_conv1d_fwd")
@register_fake("_C::causal_conv1d_fwd")
def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor],
conv_states: Optional[torch.Tensor],
Expand All @@ -456,15 +467,15 @@ def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
silu_activation: bool) -> torch.Tensor:
return torch.empty_like(x)

@torch.library.register_fake("_C::causal_conv1d_update")
@register_fake("_C::causal_conv1d_update")
def causal_conv1d_update_fake(
x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor], silu_activation: bool,
cache_seqlens: Optional[torch.Tensor],
conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor:
return torch.empty_like(x)

@torch.library.register_fake("_C::selective_scan_fwd")
@register_fake("_C::selective_scan_fwd")
def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor,
A: torch.Tensor, B: torch.Tensor,
C: torch.Tensor, D_: Optional[torch.Tensor],
Expand Down Expand Up @@ -639,7 +650,7 @@ def machete_prepack_B(b_q_weight: torch.Tensor,

if hasattr(torch.ops._C, "permute_cols"):

@torch.library.register_fake("_C::permute_cols")
@register_fake("_C::permute_cols")
def _permute_cols_fake(a: torch.Tensor,
perm: torch.Tensor) -> torch.Tensor:
return torch.empty_like(a)
Expand Down Expand Up @@ -837,7 +848,7 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,

if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):

@torch.library.register_fake("_moe_C::marlin_gemm_moe")
@register_fake("_moe_C::marlin_gemm_moe")
def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor,
sorted_ids: torch.Tensor,
topk_weights: torch.Tensor,
Expand Down

0 comments on commit bd37b9f

Please sign in to comment.