Commit c36564f

skip some tests if not supported
bnellnm committed Oct 8, 2024
1 parent 745adb3 commit c36564f
Showing 2 changed files with 7 additions and 0 deletions.
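
Both files gain @pytest.mark.skipif guards that probe for the compiled custom op before the test runs. A minimal standalone sketch of the probe pattern, with illustrative op and test names (note that pytest requires a reason= string when the skipif condition is a boolean):

import pytest
import torch

# hasattr on a torch.ops namespace resolves the op lazily and returns False
# when the extension that registers it was not built or loaded, making it a
# cheap availability probe at collection time.
@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_dequantize"),
                    reason="awq_dequantize is not supported")
def test_example():
    ...
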
3 changes: 3 additions & 0 deletions tests/kernels/test_awq.py
@@ -1,11 +1,13 @@
import os

import pytest
import torch

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401


@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_dequantize"),
                    reason="awq_dequantize is not supported")
def test_awq_dequantize_opcheck():
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
qweight = torch.randint(-2000000000,
@@ -21,6 +23,7 @@ def test_awq_dequantize_opcheck():
(qweight, scales, zeros, split_k_iters, thx, thy))


@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_gemm"),
                    reason="awq_gemm is not supported")
def test_awq_gemm_opcheck():
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
input = torch.rand((2, 8192), device='cuda', dtype=torch.float16)
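
As a quick local check of which AWQ ops a given build registered (assuming, as the test file's own import suggests, that importing vllm._custom_ops loads the compiled _C extension):

import torch

from vllm import _custom_ops  # noqa: F401 (assumed to load the _C extension)

for name in ("awq_dequantize", "awq_gemm"):
    print(name, "available:", hasattr(torch.ops._C, name))
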
4 changes: 4 additions & 0 deletions tests/kernels/test_awq_marlin.py
@@ -7,6 +7,7 @@

from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe,
                                 torch_moe_single)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
    fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
@@ -21,6 +22,9 @@
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.skipif(
    not (ops.supports_moe_ops
         and hasattr(torch.ops._moe_C, "marlin_gemm_moe")),
    reason="marlin_gemm_moe is not supported")
def test_fused_marlin_moe_awq(
    m: int,
    n: int,
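
The ops.supports_moe_ops flag referenced above comes from vllm._custom_ops. A plausible sketch of how such an optional-extension flag is defined (an assumption about the module's internals, not a quote of them):

import contextlib

# Sketch: the flag is True only when the optional MoE extension imports
# cleanly; builds without the kernels simply leave it False.
supports_moe_ops = False
with contextlib.suppress(ImportError):
    import vllm._moe_C  # noqa: F401
    supports_moe_ops = True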
