[Bugfix] Try to handle older versions of pytorch #9086

Merged (5 commits, Oct 8, 2024)
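The diff below does two things: it adds pytest.mark.skipif guards so the AWQ and AWQ-Marlin kernel tests are skipped when the corresponding custom ops are missing from the compiled extension, and it replaces direct uses of torch.library.register_fake in vllm/_custom_ops.py with a compatibility shim that falls back to the older torch.library.impl_abstract name. A condensed sketch of that shim follows; the inline comments are added here for context and are not part of the diff itself.

from typing import TYPE_CHECKING

if TYPE_CHECKING:

    # Define a stand-in so type checkers always see the same name,
    # regardless of the installed PyTorch version.
    def register_fake(fn):
        return lambda name: fn

else:
    try:
        # Newer PyTorch exposes the decorator as torch.library.register_fake.
        from torch.library import register_fake
    except ImportError:
        # Older releases only ship the previous name, impl_abstract,
        # which is used the same way as a decorator.
        from torch.library import impl_abstract as register_fake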
tests/kernels/test_awq.py (5 additions, 0 deletions)
@@ -1,11 +1,14 @@
import os

import pytest
import torch

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401


@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_dequantize"),
reason="AWQ is not supported on this GPU type.")
def test_awq_dequantize_opcheck():
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
qweight = torch.randint(-2000000000,
@@ -21,6 +24,8 @@ def test_awq_dequantize_opcheck():
(qweight, scales, zeros, split_k_iters, thx, thy))


@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_gemm"),
reason="AWQ is not supported on this GPU type.")
def test_awq_gemm_opcheck():
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
input = torch.rand((2, 8192), device='cuda', dtype=torch.float16)
tests/kernels/test_awq_marlin.py (4 additions, 0 deletions)
@@ -7,6 +7,7 @@

from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe,
torch_moe_single)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
@@ -21,6 +22,9 @@
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.skipif(not (ops.supports_moe_ops
and hasattr(torch.ops._moe_C, "marlin_gemm_moe")),
reason="Marlin is not supported on this GPU type.")
def test_fused_marlin_moe_awq(
m: int,
n: int,
vllm/_custom_ops.py (32 additions, 21 deletions)
@@ -1,8 +1,9 @@
import contextlib
import functools
from typing import List, Optional, Tuple, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch
import torch.library

import vllm.envs as envs
from vllm._core_ext import ScalarType
@@ -25,6 +26,16 @@
import vllm._moe_C # noqa: F401
supports_moe_ops = True

if TYPE_CHECKING:

    def register_fake(fn):
        return lambda name: fn
else:
    try:
        from torch.library import register_fake
    except ImportError:
        from torch.library import impl_abstract as register_fake


def hint_on_error(fn):

@@ -266,7 +277,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,

if hasattr(torch.ops._C, "gptq_gemm"):

@torch.library.register_fake("_C::gptq_gemm")
@register_fake("_C::gptq_gemm")
def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor,
b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
@@ -301,15 +312,15 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,

if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):

@torch.library.register_fake("_C::gptq_marlin_24_gemm")
@register_fake("_C::gptq_marlin_24_gemm")
def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_meta: torch.Tensor, b_scales: torch.Tensor,
workspace: torch.Tensor,
b_q_type: ScalarType, size_m: int,
size_n: int, size_k: int) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

@torch.library.register_fake("_C::gptq_marlin_gemm")
@register_fake("_C::gptq_marlin_gemm")
def _gptq_marlin_gemm_fake(a: torch.Tensor,
b_q_weight: torch.Tensor,
b_scales: torch.Tensor,
@@ -326,12 +337,12 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor,
use_fp32_reduce: bool = False) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

@torch.library.register_fake("_C::ggml_dequantize")
@register_fake("_C::ggml_dequantize")
def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int,
n: int) -> torch.Tensor:
return torch.empty((m, n), dtype=torch.float16, device=W.device)

@torch.library.register_fake("_C::ggml_mul_mat_vec_a8")
@register_fake("_C::ggml_mul_mat_vec_a8")
def _ggml_mul_mat_vec_a8_fake(
W: torch.Tensor,
X: torch.Tensor,
@@ -340,7 +351,7 @@ def _ggml_mul_mat_vec_a8_fake(
) -> torch.Tensor:
return torch.empty((1, row), dtype=torch.float16, device=W.device)

@torch.library.register_fake("_C::ggml_mul_mat_a8")
@register_fake("_C::ggml_mul_mat_a8")
def _ggml_mul_mat_a8_fake(
W: torch.Tensor,
X: torch.Tensor,
@@ -350,7 +361,7 @@ def _ggml_mul_mat_a8_fake(
batch = X.size(0)
return torch.empty((batch, row), dtype=torch.float16, device=W.device)

@torch.library.register_fake("_C::marlin_qqq_gemm")
@register_fake("_C::marlin_qqq_gemm")
def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
s_tok: torch.Tensor, s_ch: torch.Tensor,
s_group: torch.Tensor, workspace: torch.Tensor,
@@ -360,7 +371,7 @@ def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
dtype=torch.float16,
device=a.device)

@torch.library.register_fake("_C::marlin_gemm")
@register_fake("_C::marlin_gemm")
def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
size_m: int, size_n: int,
@@ -369,7 +380,7 @@ def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
dtype=torch.float16,
device=a.device)

@torch.library.register_fake("_C::awq_dequantize")
@register_fake("_C::awq_dequantize")
def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor:
@@ -380,7 +391,7 @@ def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
dtype=scales.dtype,
device=scales.device)

@torch.library.register_fake("_C::awq_gemm")
@register_fake("_C::awq_gemm")
def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
qzeros: torch.Tensor, scales: torch.Tensor,
split_k_iters: int) -> torch.Tensor:
@@ -389,7 +400,7 @@ def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
dtype=input.dtype,
device=input.device).sum(0)

@torch.library.register_fake("_C::aqlm_gemm")
@register_fake("_C::aqlm_gemm")
def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
codebooks: torch.Tensor, scales: torch.Tensor,
codebook_partition_sizes: List[int],
@@ -405,7 +416,7 @@ def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
output_sizes.append(-1)
return flat_output.reshape(tuple(output_sizes))

@torch.library.register_fake("_C::aqlm_dequant")
@register_fake("_C::aqlm_dequant")
def _aqlm_dequant_fake(
codes: torch.Tensor, codebooks: torch.Tensor,
codebook_partition_sizes: List[int]) -> torch.Tensor:
@@ -415,14 +426,14 @@ def _aqlm_dequant_fake(
dtype=codebooks.dtype,
device=codebooks.device)

@torch.library.register_fake("_C::fp8_marlin_gemm")
@register_fake("_C::fp8_marlin_gemm")
def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
num_bits: int, size_m: int, size_n: int,
size_k: int) -> torch.Tensor:
return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)

@torch.library.register_fake("_C::machete_gemm")
@register_fake("_C::machete_gemm")
def machete_gemm_fake(
a: torch.Tensor,
# Should be the tensor returned by machete_prepack_B
@@ -440,13 +451,13 @@ def machete_gemm_fake(
n = b_q.size(1)
return torch.empty((m, n), device=a.device, dtype=a.dtype)

@torch.library.register_fake("_C::machete_prepack_B")
@register_fake("_C::machete_prepack_B")
def machete_prepack_B_fake(b_q_weight: torch.Tensor,
b_type: ScalarType) -> torch.Tensor:
return torch.empty_like(b_q_weight,
memory_format=torch.contiguous_format)

@torch.library.register_fake("_C::causal_conv1d_fwd")
@register_fake("_C::causal_conv1d_fwd")
def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor],
conv_states: Optional[torch.Tensor],
@@ -456,15 +467,15 @@ def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
silu_activation: bool) -> torch.Tensor:
return torch.empty_like(x)

@torch.library.register_fake("_C::causal_conv1d_update")
@register_fake("_C::causal_conv1d_update")
def causal_conv1d_update_fake(
x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor], silu_activation: bool,
cache_seqlens: Optional[torch.Tensor],
conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor:
return torch.empty_like(x)

@torch.library.register_fake("_C::selective_scan_fwd")
@register_fake("_C::selective_scan_fwd")
def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor,
A: torch.Tensor, B: torch.Tensor,
C: torch.Tensor, D_: Optional[torch.Tensor],
@@ -639,7 +650,7 @@ def machete_prepack_B(b_q_weight: torch.Tensor,

if hasattr(torch.ops._C, "permute_cols"):

@torch.library.register_fake("_C::permute_cols")
@register_fake("_C::permute_cols")
def _permute_cols_fake(a: torch.Tensor,
perm: torch.Tensor) -> torch.Tensor:
return torch.empty_like(a)
@@ -837,7 +848,7 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,

if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):

@torch.library.register_fake("_moe_C::marlin_gemm_moe")
@register_fake("_moe_C::marlin_gemm_moe")
def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor,
sorted_ids: torch.Tensor,
topk_weights: torch.Tensor,
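A self-contained usage sketch (not part of the PR) showing that the shim behaves as a drop-in replacement for torch.library.register_fake: the fake implementation is registered through whichever name the installed PyTorch provides. The operator mylib::double and its implementations are hypothetical and exist only for illustration.

import torch
import torch.library

try:
    from torch.library import register_fake  # newer PyTorch
except ImportError:
    from torch.library import impl_abstract as register_fake  # older PyTorch

# Declare a custom op schema and give it a real CPU implementation.
torch.library.define("mylib::double", "(Tensor x) -> Tensor")


@torch.library.impl("mylib::double", "cpu")
def _double_cpu(x: torch.Tensor) -> torch.Tensor:
    return x * 2


# Register the fake/meta implementation through the shim; only shapes,
# dtypes and devices matter here, which is what tracing needs.
@register_fake("mylib::double")
def _double_fake(x: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)


print(torch.ops.mylib.double(torch.ones(3)))  # tensor([2., 2., 2.])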