From 7842c0d3867c60755a83ba5fdb11ed131dce435c Mon Sep 17 00:00:00 2001 From: bnellnm <49004751+bnellnm@users.noreply.github.com> Date: Wed, 25 Sep 2024 10:35:52 -0400 Subject: [PATCH] [Kernel] Fullgraph and opcheck tests (#8479) --- .buildkite/test-pipeline.yaml | 19 +++- csrc/mamba/mamba_ssm/selective_scan_fwd.cu | 2 +- csrc/torch_bindings.cpp | 4 +- tests/compile/test_full_graph.py | 45 ++------ tests/compile/test_full_graph_multi_gpu.py | 22 ++++ tests/compile/test_full_graph_smoke.py | 13 +++ tests/compile/utils.py | 104 ++++++++++++++++++ tests/conftest.py | 6 + tests/kernels/test_aqlm.py | 37 +++++++ tests/kernels/test_attention.py | 9 +- tests/kernels/test_awq.py | 38 +++++++ tests/kernels/test_causal_conv1d.py | 74 ++++++++++++- tests/kernels/test_cutlass.py | 10 ++ tests/kernels/test_flash_attn.py | 61 +++++----- tests/kernels/test_fp8_quant.py | 29 +++++ tests/kernels/test_ggml.py | 22 ++++ tests/kernels/test_gptq.py | 29 +++++ tests/kernels/test_mamba_ssm.py | 66 +++++++++++ tests/kernels/test_marlin_gemm.py | 15 +++ tests/kernels/test_moe.py | 60 +++++++++- tests/kernels/test_rotary_embedding.py | 62 +++++++++++ tests/kernels/test_utils.py | 24 ++++ tests/kernels/utils.py | 43 +++++++- vllm/_custom_ops.py | 61 +++++----- .../layers/mamba/ops/mamba_ssm.py | 4 +- .../layers/quantization/gptq.py | 1 + 26 files changed, 744 insertions(+), 116 deletions(-) create mode 100644 tests/compile/test_full_graph_multi_gpu.py create mode 100644 tests/compile/test_full_graph_smoke.py create mode 100644 tests/compile/utils.py create mode 100644 tests/kernels/test_aqlm.py create mode 100644 tests/kernels/test_awq.py create mode 100644 tests/kernels/test_ggml.py create mode 100644 tests/kernels/test_gptq.py create mode 100644 tests/kernels/test_rotary_embedding.py create mode 100644 tests/kernels/test_utils.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 854147fbf71f6..b6fb826d90ce7 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -70,7 +70,7 @@ steps: - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py - + - label: Core Test # 10min mirror_hardwares: [amd] fast_check: true @@ -210,6 +210,21 @@ steps: command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py parallelism: 4 +- label: "PyTorch Fullgraph Smoke Test" + fast_check: true + source_file_dependencies: + - vllm/ + - tests/compile + commands: + - pytest -v -s compile/test_full_graph_smoke.py + +- label: "PyTorch Fullgraph Test" + source_file_dependencies: + - vllm/ + - tests/compile + commands: + - pytest -v -s compile/test_full_graph.py + - label: Kernels Test %N # 30min each mirror_hardwares: [amd] source_file_dependencies: @@ -355,7 +370,7 @@ steps: - tests/distributed/ - vllm/compilation commands: - - pytest -v -s ./compile/test_full_graph.py + - pytest -v -s ./compile/test_full_graph_multi_gpu.py - pytest -v -s ./compile/test_wrapper.py - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed' - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index df968dda92adc..d7829f5d583d4 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -586,7 +586,7 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { selective_scan_fwd_cuda(params, stream); }); - std::vector result = {out, x.value()}; + std::vector result = {out}; if (has_z) { result.push_back(out_z); } return result; } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 4b374af5ae24e..b6ba1b2a26e10 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -275,7 +275,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor! A, Tensor! B, Tensor! C," "Tensor? D_, Tensor? z_, Tensor? delta_bias_," "bool delta_softplus," - "Tensor? index_, Tensor(a! -> *)? x) -> Tensor(a)[]"); + "Tensor? index_, Tensor!? x) -> Tensor[]"); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); ops.def( @@ -292,7 +292,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor? bias_," "Tensor? seq_idx_," "Tensor? initial_states_," - "Tensor? final_states_out_," + "Tensor!? final_states_out_," "bool silu_activation) -> Tensor"); ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd); #endif diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 2e309aaa58d48..5dd65ad7236f9 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -1,42 +1,13 @@ -import os - import pytest -from vllm.utils import cuda_device_count_stateless - -from ..utils import fork_new_process_for_each_test - - -@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) -@pytest.mark.parametrize("tp_size", [1, 2]) -@fork_new_process_for_each_test -def test_full_graph(model, tp_size): - - # Skip the test if there are not enough CUDA devices. - if cuda_device_count_stateless() < tp_size: - pytest.skip("Not enough CUDA devices for the test.") - - # make sure these models can be captured in full graph mode - if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ: - os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1" +from vllm.compilation.backends import vllm_backend - from vllm import LLM, SamplingParams - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - sampling_params = SamplingParams(temperature=0) - llm = LLM(model=model, - enforce_eager=True, - tensor_parallel_size=tp_size, - disable_custom_all_reduce=True) +from .utils import TEST_MODELS, check_full_graph_support - outputs = llm.generate(prompts, sampling_params) - # Print the outputs. - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") +@pytest.mark.parametrize("model_info", TEST_MODELS) +@pytest.mark.parametrize("backend", ["eager", vllm_backend]) +def test_full_graph(model_info, backend): + model = model_info[0] + model_kwargs = model_info[1] + check_full_graph_support(model, model_kwargs, backend, tp_size=1) diff --git a/tests/compile/test_full_graph_multi_gpu.py b/tests/compile/test_full_graph_multi_gpu.py new file mode 100644 index 0000000000000..e9883d5254e72 --- /dev/null +++ b/tests/compile/test_full_graph_multi_gpu.py @@ -0,0 +1,22 @@ +import pytest + +from vllm.compilation.backends import vllm_backend +from vllm.utils import cuda_device_count_stateless + +from ..utils import fork_new_process_for_each_test +from .utils import TEST_MODELS_SMOKE, check_full_graph_support + + +@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE) +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize("backend", ["eager", vllm_backend]) +@fork_new_process_for_each_test +def test_full_graph_multi_gpu(model_info, tp_size, backend): + model = model_info[0] + model_kwargs = model_info[1] + + # Skip the test if there are not enough CUDA devices. + if cuda_device_count_stateless() < tp_size: + pytest.skip("Not enough CUDA devices for the test.") + + check_full_graph_support(model, model_kwargs, backend, tp_size=tp_size) diff --git a/tests/compile/test_full_graph_smoke.py b/tests/compile/test_full_graph_smoke.py new file mode 100644 index 0000000000000..0c5a95b4ead4c --- /dev/null +++ b/tests/compile/test_full_graph_smoke.py @@ -0,0 +1,13 @@ +import pytest + +from vllm.compilation.backends import vllm_backend + +from .utils import TEST_MODELS_SMOKE, check_full_graph_support + + +@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE) +@pytest.mark.parametrize("backend", ["eager", vllm_backend]) +def test_full_graph(model_info, backend): + model = model_info[0] + model_kwargs = model_info[1] + check_full_graph_support(model, model_kwargs, backend, tp_size=1) diff --git a/tests/compile/utils.py b/tests/compile/utils.py new file mode 100644 index 0000000000000..2d06a0946d911 --- /dev/null +++ b/tests/compile/utils.py @@ -0,0 +1,104 @@ +import os + +import torch + +from tests.quantization.utils import is_quant_method_supported +from vllm import LLM, SamplingParams +from vllm.plugins import set_torch_compile_backend +from vllm.utils import is_hip + +TEST_MODELS_SMOKE = [ + ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", { + "quantization": "compressed-tensors" + }), + ("meta-llama/Meta-Llama-3-8B", {}), +] + +TEST_MODELS = [ + ("facebook/opt-125m", {}), + ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", { + "dtype": torch.float16, + "quantization": "compressed-tensors" + }), + ("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", { + "dtype": torch.float16, + "quantization": "fp8" + }), + ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", { + "quantization": "compressed-tensors" + }), + ("meta-llama/Meta-Llama-3-8B", {}), +] + +# TODO: enable in pytorch 2.5 +if False and is_quant_method_supported("aqlm"): # noqa: SIM223 + TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", { + "quantization": "aqlm" + })) + +# TODO: enable in pytorch 2.5 +if False and is_quant_method_supported("gguf"): # noqa: SIM223 + TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", { + "quantization": "gguf" + })) + +if is_quant_method_supported("gptq"): + TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", { + "quantization": "gptq" + })) + +if is_quant_method_supported("gptq_marlin"): + TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", { + "quantization": "gptq_marlin" + })) + +if is_quant_method_supported("gptq_marlin_24"): + TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", { + "quantization": "gptq_marlin_24" + })) + +if is_quant_method_supported("marlin"): + TEST_MODELS.append(("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", { + "quantization": "marlin" + })) + +if not is_hip() and is_quant_method_supported("awq"): + TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", { + "quantization": "AWQ" + })) + + +def check_full_graph_support(model, model_kwargs, backend, tp_size=1): + # make sure these models can be captured in full graph mode + if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ: + os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1" + os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1" + + # Inductor doesn't support fp8/gptq_marlin_24 yet. + quantization = model_kwargs.get("quantization") + if (quantization == "fp8" or quantization == "gptq_marlin" + or quantization == "gptq_marlin_24") and backend != "eager": + return + + set_torch_compile_backend(backend) + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = SamplingParams(temperature=0) + llm = LLM(model=model, + enforce_eager=True, + tensor_parallel_size=tp_size, + disable_custom_all_reduce=True, + **model_kwargs) + + outputs = llm.generate(prompts, sampling_params) + + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/tests/conftest.py b/tests/conftest.py index dcd9afdae3c14..354862e3579ac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -169,6 +169,12 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool): cleanup() +@pytest.fixture(autouse=True) +def dynamo_reset(): + yield + torch._dynamo.reset() + + @pytest.fixture def example_prompts() -> List[str]: prompts = [] diff --git a/tests/kernels/test_aqlm.py b/tests/kernels/test_aqlm.py new file mode 100644 index 0000000000000..860fb66b17354 --- /dev/null +++ b/tests/kernels/test_aqlm.py @@ -0,0 +1,37 @@ +import torch + +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops # noqa: F401 + + +def test_aqlm_dequant_opcheck(): + codes = torch.randint(-32768, + 32767, (22016, 512, 1), + device='cuda', + dtype=torch.int16) + codebooks = torch.rand((2, 65536, 1, 8), + device='cuda', + dtype=torch.float16) + codebook_partition_sizes = [11008, 11008] + + opcheck(torch.ops._C.aqlm_dequant, + (codes, codebooks, codebook_partition_sizes)) + + +def test_aqlm_gemm_opcheck(): + input = torch.rand((4, 4096), device='cuda', dtype=torch.float16) + codes = torch.randint(-32768, + 32767, (12288, 512, 1), + device='cuda', + dtype=torch.int16) + codebooks = torch.rand((3, 65536, 1, 8), + device='cuda', + dtype=torch.float16) + scales = torch.rand((12288, 1, 1, 1), device='cuda', dtype=torch.float16) + codebook_partition_sizes = [4096, 4096, 4096] + bias = None + + opcheck(torch.ops._C.aqlm_gemm, + (input, codes, codebooks, scales, codebook_partition_sizes, None)) + opcheck(torch.ops._C.aqlm_gemm, + (input, codes, codebooks, scales, codebook_partition_sizes, bias)) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index ecab512cba16f..52f1ecd176963 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -205,7 +205,8 @@ def test_paged_attention( (output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), - cond=(head_size == HEAD_SIZES[0])) + cond=(head_size == HEAD_SIZES[0] + and block_size == BLOCK_SIZES[0])) elif version in ("v2", "rocm"): num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) @@ -246,7 +247,8 @@ def test_paged_attention( key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), - cond=(head_size == HEAD_SIZES[0])) + cond=(head_size == HEAD_SIZES[0] + and block_size == BLOCK_SIZES[0])) else: ops.paged_attention_rocm( @@ -274,7 +276,8 @@ def test_paged_attention( key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale), - cond=(head_size == HEAD_SIZES[0])) + cond=(head_size == HEAD_SIZES[0] + and block_size == BLOCK_SIZES[0])) else: raise AssertionError(f"Unknown version: {version}") diff --git a/tests/kernels/test_awq.py b/tests/kernels/test_awq.py new file mode 100644 index 0000000000000..e421aca48af2c --- /dev/null +++ b/tests/kernels/test_awq.py @@ -0,0 +1,38 @@ +import os + +import torch + +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops # noqa: F401 + + +def test_awq_dequantize_opcheck(): + os.environ["VLLM_USE_TRITON_AWQ"] = "0" + qweight = torch.randint(-2000000000, + 2000000000, (8192, 256), + device='cuda', + dtype=torch.int32) + scales = torch.rand((64, 2048), device='cuda', dtype=torch.float16) + zeros = torch.empty((64, 256), device='cuda', dtype=torch.int32) + split_k_iters = 0 + thx = 0 + thy = 0 + opcheck(torch.ops._C.awq_dequantize, + (qweight, scales, zeros, split_k_iters, thx, thy)) + + +def test_awq_gemm_opcheck(): + os.environ["VLLM_USE_TRITON_AWQ"] = "0" + input = torch.rand((2, 8192), device='cuda', dtype=torch.float16) + qweight = torch.randint(-2000000000, + 2000000000, (8192, 256), + device='cuda', + dtype=torch.int32) + scales = torch.randint(-2000000000, + 2000000000, (64, 256), + device='cuda', + dtype=torch.int32) + qzeros = torch.empty((64, 2048), device='cuda', dtype=torch.float16) + split_k_iters = 8 + opcheck(torch.ops._C.awq_gemm, + (input, qweight, qzeros, scales, split_k_iters)) diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 043c4923bd660..744e445fe6673 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -5,6 +5,8 @@ import torch.nn.functional as F from einops import rearrange +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops # noqa: F401 from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.utils import seed_everything @@ -84,6 +86,64 @@ def causal_conv1d_update_ref(x: torch.Tensor, return (out if activation is None else F.silu(out)).to(dtype=dtype_in) +def causal_conv1d_opcheck_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + seq_idx: Optional[torch.Tensor] = None, + initial_states: Optional[torch.Tensor] = None, + return_final_states: bool = False, + final_states_out=None, + activation: Optional[str] = "silu", +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert (initial_states is + None), "initial_states must be None if seq_idx is not None" + assert (not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and (initial_states.stride(2) != 1 + and initial_states.stride(1) != 1): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert (final_states_out.stride(2) == 1 + or final_states_out.stride(1) == 1) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty(batch, + width - 1, + dim, + device=x.device, + dtype=x.dtype).transpose(1, 2) + else: + final_states_out = None + + opcheck(torch.ops._C.causal_conv1d_fwd, + (x, weight, bias, seq_idx, initial_states, final_states_out, + activation in ["silu", "swish"])) + + @pytest.mark.parametrize("return_final_states", [False, True]) @pytest.mark.parametrize("has_initial_states", [False, True]) @pytest.mark.parametrize("channel_last", [False, True]) @@ -149,6 +209,14 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, initial_states=initial_states_ref, return_final_states=return_final_states, activation=activation) + + causal_conv1d_opcheck_fn(x_ref, + weight_ref, + bias_ref, + initial_states=initial_states_ref, + return_final_states=return_final_states, + activation=activation) + if return_final_states: assert final_states is not None and final_states_ref is not None assert torch.allclose(final_states, @@ -205,6 +273,10 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation, assert torch.equal(conv_state, conv_state_ref) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + opcheck( + torch.ops._C.causal_conv1d_update, + (x, conv_state, weight, bias, activation in ["silu", "swish"], None)) + @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) @@ -258,7 +330,5 @@ def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias, bias, activation=activation) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index cc4ca2e91e76f..993e67e827ea0 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -15,6 +15,9 @@ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] +capability = current_platform.get_device_capability() +capability = capability[0] * 10 + capability[1] + def to_fp8(tensor: torch.Tensor): finfo = torch.finfo(torch.float8_e4m3fn) @@ -74,6 +77,9 @@ def cutlass_fp8_gemm_helper(m: int, torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2) + opcheck(torch.ops._C.cutlass_scaled_mm, + (out, a, b, scale_a, scale_b, bias)) + def cutlass_int8_gemm_helper(m: int, n: int, @@ -425,3 +431,7 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool): baseline = torch.mm(scale_a * a.to(dtype=torch.float32), scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16) torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) + + +def test_cutlass_support_opcheck(): + opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, )) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 8e960d098c408..71f61c19dd951 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -4,6 +4,7 @@ import torch import vllm.attention.backends.flash_attn # noqa: F401 +from tests.kernels.utils import opcheck from vllm.utils import seed_everything NUM_HEADS = [(4, 4), (8, 2), (16, 2)] @@ -127,19 +128,19 @@ def test_flash_attn_with_paged_kv( else: test_utils = ["test_faketensor"] - torch.library.opcheck(torch.ops.vllm.flash_attn_with_kvcache, - args=tuple(), - kwargs=dict( - decode_query=query.unsqueeze(1), - key_cache=key_cache, - value_cache=value_cache, - softmax_scale=scale, - causal=True, - block_table=block_tables, - cache_seqlens=kv_lens_tensor, - softcap=soft_cap if soft_cap is not None else 0, - ), - test_utils=test_utils) + opcheck(torch.ops.vllm.flash_attn_with_kvcache, + args=tuple(), + kwargs=dict( + decode_query=query.unsqueeze(1), + key_cache=key_cache, + value_cache=value_cache, + softmax_scale=scale, + causal=True, + block_table=block_tables, + cache_seqlens=kv_lens_tensor, + softcap=soft_cap if soft_cap is not None else 0, + ), + test_utils=test_utils) ref_output = ref_paged_attn( query=query, @@ -232,23 +233,23 @@ def test_varlen_with_paged_kv( else: test_utils = ["test_faketensor"] - torch.library.opcheck(torch.ops.vllm.flash_attn_varlen_func, - args=tuple(), - kwargs=dict( - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=cu_query_lens, - cu_seqlens_k=cu_kv_lens, - max_seqlen_q=max_query_len, - max_seqlen_k=max_kv_len, - softmax_scale=scale, - causal=True, - window_size=window_size, - block_table=block_tables, - softcap=soft_cap if soft_cap is not None else 0, - ), - test_utils=test_utils) + opcheck(torch.ops.vllm.flash_attn_varlen_func, + args=tuple(), + kwargs=dict( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=cu_query_lens, + cu_seqlens_k=cu_kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len, + softmax_scale=scale, + causal=True, + window_size=window_size, + block_table=block_tables, + softcap=soft_cap if soft_cap is not None else 0, + ), + test_utils=test_utils) ref_output = ref_paged_attn( query=query, diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py index 49f5ce53aab54..c18f5f468dc5a 100644 --- a/tests/kernels/test_fp8_quant.py +++ b/tests/kernels/test_fp8_quant.py @@ -5,6 +5,7 @@ from tests.kernels.quant_utils import (FP8_DTYPE, ref_dynamic_per_tensor_fp8_quant, ref_dynamic_per_token_quant) +from tests.kernels.utils import opcheck from vllm.utils import seed_everything DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -16,6 +17,26 @@ SEEDS = [0] +def opcheck_fp8_quant(output, + input, + scale=None, + scale_ub=None, + use_per_token_if_dynamic=False): + if scale is not None: + opcheck(torch.ops._C.static_scaled_fp8_quant, (output, input, scale)) + elif use_per_token_if_dynamic: + scale = torch.empty((input.shape[0], 1), + device=input.device, + dtype=torch.float32) + opcheck(torch.ops._C.dynamic_per_token_scaled_fp8_quant, + (output, input, scale, scale_ub)) + else: + scale = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + opcheck(torch.ops._C.dynamic_scaled_fp8_quant, (output, input, scale)) + + @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @@ -41,6 +62,12 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, torch.testing.assert_close(ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)) + opcheck_fp8_quant(ops_out, + x, + None, + scale_ub, + use_per_token_if_dynamic=True) + @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @@ -60,6 +87,8 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, torch.testing.assert_close(ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)) + opcheck_fp8_quant(ops_out, x) + # Regression test for a case with large activations where an int32 index cannot # represent the number of elements. diff --git a/tests/kernels/test_ggml.py b/tests/kernels/test_ggml.py new file mode 100644 index 0000000000000..dddb285bf26ec --- /dev/null +++ b/tests/kernels/test_ggml.py @@ -0,0 +1,22 @@ +import gguf +import pytest +import torch + +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops # noqa: F401 + + +@pytest.mark.parametrize("quant_type", [12]) +def test_ggml_opcheck(quant_type): + block_size, type_size = gguf.GGML_QUANT_SIZES[quant_type] + shape = [256, 1152] + qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8) + m = qweight.shape[0] + n = qweight.shape[1] // type_size * block_size + opcheck(torch.ops._C.ggml_dequantize, (qweight, quant_type, m, n)) + + x = torch.rand((m, 512), device='cuda', dtype=torch.float16) + opcheck(torch.ops._C.ggml_mul_mat_a8, + (qweight, x, quant_type, qweight.shape[0])) + opcheck(torch.ops._C.ggml_mul_mat_vec_a8, + (qweight, x, quant_type, qweight.shape[0])) diff --git a/tests/kernels/test_gptq.py b/tests/kernels/test_gptq.py new file mode 100644 index 0000000000000..c1ca6f1f5191b --- /dev/null +++ b/tests/kernels/test_gptq.py @@ -0,0 +1,29 @@ +import torch + +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops # noqa: F401 + + +def test_gptq_shuffle_opcheck(): + weight = torch.randint(-2000000, + 2000000, (1792, 4096), + device='cuda', + dtype=torch.int32) + perm = torch.empty((0, ), device='cuda', dtype=torch.int32) + bit = 4 + opcheck(torch.ops._C.gptq_shuffle, (weight, perm, bit)) + + +def test_gptq_gemm_opcheck(): + a = torch.rand((240, 4096), device='cuda', dtype=torch.float16) + weight = torch.randint(-2000000, + 2000000, (512, 6144), + device='cuda', + dtype=torch.int32) + zeros = torch.zeros((32, 768), device='cuda', dtype=torch.int32) + scales = torch.rand((32, 6144), device='cuda', dtype=torch.float16) + idx = torch.empty((0, ), device='cuda', dtype=torch.int32) + use_exllama = True + bit = 4 + opcheck(torch.ops._C.gptq_gemm, + (a, weight, zeros, scales, idx, use_exllama, bit)) diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index 366475222a68e..5a6149562e886 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -3,6 +3,8 @@ import torch.nn.functional as F from einops import rearrange, repeat +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops # noqa: F401 from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) from vllm.utils import seed_everything @@ -161,6 +163,59 @@ def selective_scan_ref(u, return out if not return_last_state else (out, last_state) +def selective_scan_opcheck_fn(u, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + return_last_state=False, + position_indices=None, + prev_state=None): + """if return_last_state is True, returns (out, last_state) + last_state has shape (batch, dim, dstate). + """ + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + if B.dim() == 3: + B = B.unsqueeze(1) + if C.dim() == 3: + C = C.unsqueeze(1) + n_chunks = int((u.shape[-1] + 2048 - 1) / 2048) + x = torch.zeros(( + u.shape[0], + u.shape[1], + n_chunks, + int(A.shape[1] * 2), + ), + device=u.device, + dtype=torch.float32, + requires_grad=False) + x[:, :, 0, 0::2] = 1 + if prev_state is not None: + x[:, :, 0, 1::2].copy_(prev_state) + + # Disable test_autograd_registration for now as it seems to trigger + # a bogus error. + opcheck(torch.ops._C.selective_scan_fwd, + (u, delta, A, B, C, D, z, delta_bias, delta_softplus, + position_indices, x), + test_utils=["test_schema", "test_faketensor"]) + + @pytest.mark.parametrize('wtype', [torch.float32]) @pytest.mark.parametrize('itype', [torch.float32]) @pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096]) @@ -274,6 +329,17 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, assert state is not None and state_ref is not None assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) + selective_scan_opcheck_fn(u, + delta, + A, + B, + C, + D, + z=z, + delta_bias=delta_bias, + delta_softplus=delta_softplus, + return_last_state=return_last_state) + @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py index 721d3a6a819ac..a9bb72156c39e 100644 --- a/tests/kernels/test_marlin_gemm.py +++ b/tests/kernels/test_marlin_gemm.py @@ -501,3 +501,18 @@ def test_marlin_qqq_gemm( max_diff = compute_max_diff(output, output_ref) assert max_diff < 0.04 + + +def test_marlin_gemm_opcheck(): + size_m = 2048 + size_n = 4096 + size_k = 4096 + a = torch.rand((size_m, size_n), device='cuda', dtype=torch.float16) + w = torch.randint(-5, 5, (256, 8192), device='cuda', dtype=torch.int32) + s = torch.full((32, size_k), 0.125, device='cuda', dtype=torch.float16) + wk = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL).scratch + x = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k) + y = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k) + torch.testing.assert_close(x, y) + opcheck(torch.ops._C.marlin_gemm, (a, w, s, wk, size_m, size_n, size_k)) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index b1f0516dfa0b3..c6ddcc8ce79f5 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -9,11 +9,14 @@ from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( fused_marlin_moe, single_marlin_moe) -from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, moe_align_block_size) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( marlin_quantize) from vllm.model_executor.models.mixtral import MixtralMoE @@ -247,6 +250,35 @@ def test_fused_marlin_moe( assert compute_max_diff(marlin_output, triton_output) < 4e-2 + if ops.supports_moe_ops: + token_expert_indicies = torch.empty(m, + topk, + dtype=torch.int32, + device=a.device) + + opcheck(torch.ops._moe_C.topk_softmax, ( + topk_weights, + topk_ids, + token_expert_indicies, + score.float(), + )) + + block_size_m = 4 + + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, + e) + + max_workspace_size = ((m + 255) // 256) * (max(2 * n, k) // 64) * 16 + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda", + requires_grad=False) + + opcheck(torch.ops._moe_C.marlin_gemm_moe, + (a, qweight1, sorted_token_ids, topk_weights, topk_ids, + scales1, g_idx1, sort_indices1, workspace, quant_type, m, + 2 * n, k, True, e, topk, block_size_m, True, False)) + @pytest.mark.skip("This test is here for the sake of debugging, " "don't run it in automated tests.") @@ -319,3 +351,29 @@ def test_single_marlin_moe_multiply( torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) assert compute_max_diff(marlin_output, torch_output) < 1e-2 + + +def test_moe_align_block_size_opcheck(): + num_experts = 4 + block_size = 4 + topk_ids = torch.randint(0, + num_experts, (3, 4), + dtype=torch.int32, + device='cuda') + + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + sorted_ids = torch.empty((max_num_tokens_padded, ), + dtype=torch.int32, + device=topk_ids.device) + sorted_ids.fill_(topk_ids.numel()) + max_num_m_blocks = max_num_tokens_padded // block_size + expert_ids = torch.empty((max_num_m_blocks, ), + dtype=torch.int32, + device=topk_ids.device) + num_tokens_post_pad = torch.empty((1), + dtype=torch.int32, + device=topk_ids.device) + + opcheck(torch.ops._C.moe_align_block_size, + (topk_ids, num_experts, block_size, sorted_ids, expert_ids, + num_tokens_post_pad)) diff --git a/tests/kernels/test_rotary_embedding.py b/tests/kernels/test_rotary_embedding.py new file mode 100644 index 0000000000000..da879406b3936 --- /dev/null +++ b/tests/kernels/test_rotary_embedding.py @@ -0,0 +1,62 @@ +""" +Tests for miscellaneous utilities +""" + +from typing import Optional + +import pytest +import torch + +from tests.kernels.utils import opcheck +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding + + +def rotary_embedding_opcheck(rot, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None): + cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype) + + # ops.rotary_embedding()/batched_rotary_embedding() + # are in-place operations that update the query and key tensors. + if offsets is not None: + opcheck(torch.ops._C.batched_rotary_embedding, + (positions, query, key, rot.head_size, cos_sin_cache, + rot.is_neox_style, rot.rotary_dim, offsets)) + else: + opcheck(torch.ops._C.rotary_embedding, + (positions, query, key, rot.head_size, cos_sin_cache, + rot.is_neox_style)) + + +@pytest.mark.parametrize("device", ["cuda"]) +@pytest.mark.parametrize("max_position", [11, 4096, 32768]) +@pytest.mark.parametrize("is_neox_style", [True, False]) +@pytest.mark.parametrize("rotary_dim", [32]) +@pytest.mark.parametrize("head_size", [32, 108]) +@pytest.mark.parametrize("seq_len", [11, 1024]) +def test_rotary_embedding_opcheck(dist_init, device, max_position, + is_neox_style, rotary_dim, head_size, + seq_len): + batch_size = 1 + base = 0 + num_heads = 7 + rot = RotaryEmbedding(head_size, rotary_dim, max_position, base, + is_neox_style, torch.float32) + + positions = torch.randint(0, + max_position, (batch_size, seq_len), + device=device) + query = torch.randn(batch_size, + seq_len, + num_heads * head_size, + dtype=torch.float32, + device=device) + key = torch.randn_like(query) + + rotary_embedding_opcheck(rot, positions, query, key) + offsets = torch.zeros(batch_size * seq_len, + device=device, + dtype=torch.long) + rotary_embedding_opcheck(rot, positions, query, key, offsets) diff --git a/tests/kernels/test_utils.py b/tests/kernels/test_utils.py new file mode 100644 index 0000000000000..7e5126a76f88b --- /dev/null +++ b/tests/kernels/test_utils.py @@ -0,0 +1,24 @@ +""" +Tests for miscellaneous utilities +""" + +import pytest +import torch + +from tests.kernels.utils import opcheck +from vllm.platforms import current_platform + + +def test_convert_fp8_opcheck(): + data = torch.randn((256, 256), dtype=torch.float32, device="cuda") + result = torch.empty_like(data, dtype=torch.float8_e4m3fn) + opcheck(torch.ops._C_cache_ops.convert_fp8, (result, data, 1.0, "fp8")) + + +@pytest.mark.skipif(not current_platform.is_cuda(), + reason="Only supported for CUDA") +def test_cuda_utils_opcheck(): + opcheck(torch.ops._C_cuda_utils.get_device_attribute, (0, 0)) + opcheck( + torch.ops._C_cuda_utils. + get_max_shared_memory_per_block_device_attribute, (0, )) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 5746932c30a45..08004efe9e2f8 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -2,12 +2,14 @@ import itertools import random +import unittest from numbers import Number from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union) import pytest import torch +from torch._prims_common import TensorLikeType from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL, @@ -946,6 +948,34 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters, output_under_test.view_as(ideal_output)) +# Copied/modified from torch._refs.__init__.py +def fp8_allclose( + a: TensorLikeType, + b: TensorLikeType, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False, +) -> bool: + """ + Reference implementation of torch.allclose + """ + torch._refs._check_close_args(name="torch.allclose", + a=a, + b=b, + rtol=rtol, + atol=atol) + + return bool( + torch.all( + torch.isclose(a.double(), + b.double(), + rtol=rtol, + atol=atol, + equal_nan=equal_nan)).item()) + + +# A special version of op check that has a restricted default set of test_utils +# and a patched version of allclose that supports fp8 types. def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, torch._library.custom_ops.CustomOpDef], args: Tuple[Any, ...], @@ -954,9 +984,10 @@ def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS, raise_exception: bool = True, cond: bool = True) -> Dict[str, str]: - return torch.library.opcheck( - op, - args, - kwargs, - test_utils=test_utils, - raise_exception=raise_exception) if cond else {} + with unittest.mock.patch('torch.allclose', new=fp8_allclose): + return torch.library.opcheck( + op, + args, + kwargs, + test_utils=test_utils, + raise_exception=raise_exception) if cond else {} diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index a71bafc974adf..4d71381184de5 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -20,8 +20,10 @@ if current_platform.is_rocm(): import vllm._rocm_C # noqa: F401 +supports_moe_ops = False with contextlib.suppress(ImportError): import vllm._moe_C # noqa: F401 + supports_moe_ops = True def hint_on_error(fn): @@ -253,9 +255,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_g_idx, use_exllama, bit) -# TODO: has to be a better way to do this -try: - torch.ops._C.gptq_gemm # noqa B018 +if hasattr(torch.ops._C, "gptq_gemm"): @torch.library.register_fake("_C::gptq_gemm") def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, @@ -265,8 +265,6 @@ def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, return torch.empty((a.size(0), b_q_weight.size(1)), dtype=a.dtype, device=a.device) -except Exception: - pass def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, @@ -292,9 +290,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, size_n, size_k) -# TODO: has to be a better way to do this -try: - torch.ops._C.gptq_marlin_24_gemm # noqa B018 +if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): @torch.library.register_fake("_C::gptq_marlin_24_gemm") def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, @@ -420,8 +416,8 @@ def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, @torch.library.register_fake("_C::machete_gemm") def machete_gemm_fake( a: torch.Tensor, - b_q: torch. - Tensor, # Should be the tensor returned by machete_prepack_B + # Should be the tensor returned by machete_prepack_B + b_q: torch.Tensor, b_type: ScalarType, b_scales: Optional[torch.Tensor] = None, b_zeros: Optional[torch.Tensor] = None, @@ -451,10 +447,10 @@ def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor, return torch.empty_like(x) @torch.library.register_fake("_C::causal_conv1d_update") - def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor, - weight: torch.Tensor, - bias_: Optional[torch.Tensor], - silu_activation: bool) -> torch.Tensor: + def causal_conv1d_update_fake( + x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, + bias_: Optional[torch.Tensor], silu_activation: bool, + conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor: return torch.empty_like(x) @torch.library.register_fake("_C::selective_scan_fwd") @@ -465,20 +461,11 @@ def selective_scan_fwd_fake( delta_softplus: bool, index_: Optional[torch.Tensor], x: Optional[torch.Tensor]) -> List[torch.Tensor]: a = torch.empty_like(u) - if x is not None: - b = x - else: - b = torch.empty((u.size(0), u.size(1), A.size(1)), - dtype=u.dtype, - device=u.device) if z_ is not None: c = torch.empty_like(z_) - return [a, b, c] + return [a, c] else: - return [a, b] - -except Exception: - pass + return [a] # cutlass @@ -626,16 +613,12 @@ def machete_prepack_B(b_q_weight: torch.Tensor, return torch.ops._C.machete_prepack_B(b_q_weight, b_type) -# TODO: has to be a better way to do this -try: - torch.ops._C.permute_cols # noqa B018 +if hasattr(torch.ops._C, "permute_cols"): @torch.library.register_fake("_C::permute_cols") def _permute_cols_fake(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: return torch.empty_like(a) -except Exception: - pass def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: @@ -828,6 +811,24 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, token_expert_indicies, gating_output) +if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): + + @torch.library.register_fake("_moe_C::marlin_gemm_moe") + def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor, + sorted_ids: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, b_scales: torch.Tensor, + g_idx: torch.Tensor, perm: torch.Tensor, + workspace: torch.Tensor, b_q_type: ScalarType, + size_m: int, size_n: int, size_k: int, + is_k_full: bool, num_experts: int, topk: int, + moe_block_size: int, replicate_input: bool, + apply_weights: bool) -> torch.Tensor: + return torch.empty((size_m, topk, size_n), + dtype=a.dtype, + device=a.device) + + def reshape_and_cache( key: torch.Tensor, value: torch.Tensor, diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index a0bed07ac6193..5fe451b2f1318 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -361,8 +361,8 @@ def selective_scan_fn(u, x[:, :, 0, 0::2] = 1 if prev_state is not None: x[:, :, 0, 1::2].copy_(prev_state) - out, x, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, - delta_softplus, position_indices, x) + out, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, + delta_softplus, position_indices, x) last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) if z is None: return out if not return_last_state else (out, last_state) diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index c067a76405df6..1cfadb4f42ca8 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -217,6 +217,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False) layer.qweight = Parameter(layer.qweight.data, requires_grad=False) layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False) + layer.scales = Parameter(layer.scales.data, requires_grad=False) # exllama needs to shuffle the weight after the weight is loaded # here we do the shuffle on first forward pass