[Misc] Fused MoE Marlin support for GPTQ #8217

Merged on Sep 10, 2024 (53 commits)

Commits
0abac6f
Enable 8-bit weights in Fused Marlin MoE
ElizaWszola Aug 30, 2024
fdf69c2
fix rocm
ElizaWszola Aug 30, 2024
4da163b
bad paste
ElizaWszola Aug 30, 2024
21d2337
add test case; fix imports for tests
dsikka Aug 30, 2024
080ab23
Merge branch 'main' into marlin-moe-8-bit
dsikka Aug 30, 2024
638777a
fix to adapt custom_routin_function
dsikka Aug 30, 2024
bd4b84d
Use select_experts to compute top_k tensors in fused moe
ElizaWszola Sep 2, 2024
bef6b53
bring back fused_moe_marlin -> fused_marlin_moe
ElizaWszola Sep 3, 2024
db1f07e
GPTQ Fused MoE class
ElizaWszola Sep 3, 2024
6753789
Add GPTQMarlinMoEMethod to gptq_marlin.py
ElizaWszola Sep 3, 2024
7df4014
Use FusedMoE layer for all loads
ElizaWszola Sep 4, 2024
befc52b
Merge branch 'main' into marlin-moe-8-bit
ElizaWszola Sep 4, 2024
c3dc249
Merge branch 'marlin-moe-8-bit' into gptq_fused_moe
ElizaWszola Sep 4, 2024
2fa03e5
Make sure that GPTQ runs through mixtral.py
ElizaWszola Sep 4, 2024
b45594c
remove large model
dsikka Sep 4, 2024
8a504d9
enforce float16A/scales for marlin moe
ElizaWszola Sep 4, 2024
effd2cd
Cleanup, comments
ElizaWszola Sep 4, 2024
689ea0a
Merge branch 'marlin-moe-8-bit' into gptq_fused_moe
ElizaWszola Sep 4, 2024
ec47561
cleanup
ElizaWszola Sep 4, 2024
9f97b3b
update/fix weight loading to support tp
dsikka Sep 5, 2024
b841ac4
remove 8-bit stuff for now
ElizaWszola Sep 6, 2024
a245032
Merge branch 'gptq_fused_moe' of https://github.com/neuralmagic/vllm …
ElizaWszola Sep 6, 2024
9d8a80c
fix; update large model testing cases
dsikka Sep 6, 2024
315e22f
add hack to support unfused mixtral pathway for int8
dsikka Sep 6, 2024
565cc43
fix install for tpu test
dsikka Sep 6, 2024
8886423
Move float16 typecast hack to gptq marlin moe method
ElizaWszola Sep 7, 2024
ab27497
Move output type conversion to gptq method as well
ElizaWszola Sep 7, 2024
847e860
Enable 8-bit weights in Fused Marlin MoE
ElizaWszola Aug 30, 2024
430a9cb
fix rocm
ElizaWszola Aug 30, 2024
48047aa
bad paste
ElizaWszola Aug 30, 2024
bfc4fae
add test case; fix imports for tests
dsikka Aug 30, 2024
c5a2f62
fix to adapt custom_routin_function
dsikka Aug 30, 2024
2b308c4
Use select_experts to compute top_k tensors in fused moe
ElizaWszola Sep 2, 2024
71256d4
bring back fused_moe_marlin -> fused_marlin_moe
ElizaWszola Sep 3, 2024
7aa844c
GPTQ Fused MoE class
ElizaWszola Sep 3, 2024
0f7bec3
Add GPTQMarlinMoEMethod to gptq_marlin.py
ElizaWszola Sep 3, 2024
cb0001e
Use FusedMoE layer for all loads
ElizaWszola Sep 4, 2024
33090a3
Make sure that GPTQ runs through mixtral.py
ElizaWszola Sep 4, 2024
d479837
enforce float16A/scales for marlin moe
ElizaWszola Sep 4, 2024
8baaec6
remove large model
dsikka Sep 4, 2024
8fbc181
Cleanup, comments
ElizaWszola Sep 4, 2024
839915f
cleanup
ElizaWszola Sep 4, 2024
a5bc626
remove 8-bit stuff for now
ElizaWszola Sep 6, 2024
c573fa1
update/fix weight loading to support tp
dsikka Sep 5, 2024
a991d82
fix; update large model testing cases
dsikka Sep 6, 2024
d57804d
add hack to support unfused mixtral pathway for int8
dsikka Sep 6, 2024
96fa486
fix install for tpu test
dsikka Sep 6, 2024
1faab90
Move float16 typecast hack to gptq marlin moe method
ElizaWszola Sep 7, 2024
970e06a
Move output type conversion to gptq method as well
ElizaWszola Sep 7, 2024
fd0a4f2
typo fix; fix comment
dsikka Sep 9, 2024
3ac9273
Merge branch 'gptq_fused_moe' of https://github.com/neuralmagic/vllm …
ElizaWszola Sep 9, 2024
d51a2f4
Clarify comment, change how we process bias
ElizaWszola Sep 9, 2024
12f05c5
Merge branch 'main' into gptq_fused_moe
dsikka Sep 9, 2024
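The file diffs follow. A note for reading them: per the commit "Use select_experts to compute top_k tensors in fused moe", the top-k routing tensors are computed outside the fused kernel (fused_topk in the tests, select_experts in the model layer) and passed in alongside the router scores. A minimal sketch of that routing step, assuming a vLLM install with CUDA and a GPU; shapes are illustrative, not taken from the diff:

import torch

from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk

m, k, e, topk = 4, 128, 8, 2  # tokens, hidden size, experts, experts per token
hidden_states = torch.randn(m, k, device="cuda", dtype=torch.float16)
router_logits = torch.randn(m, e, device="cuda", dtype=torch.float16)

# Same call as in tests/kernels/test_moe.py below (last argument is renormalize).
topk_weights, topk_ids = fused_topk(hidden_states, router_logits, topk, False)
print(topk_weights.shape, topk_ids.shape)  # both (m, topk)

fused_marlin_moe then consumes these two tensors together with the quantized weight stacks, the g_idx/sort_indices tensors, and the per-group scales, as exercised in the new test_fused_marlin_moe below.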
13 changes: 12 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -386,7 +386,18 @@ steps:
- vllm/
- tests/weight_loading
commands:
- bash weight_loading/run_model_weight_loading_test.sh
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt

- label: Weight Loading Multiple GPU Test - Large Models # optional
working_dir: "/vllm-workspace/tests"
num_gpus: 2
gpu: a100
optional: true
source_file_dependencies:
- vllm/
- tests/weight_loading
commands:
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt


##### multi gpus test #####
2 changes: 1 addition & 1 deletion csrc/moe/marlin_moe_ops.cu
@@ -1737,4 +1737,4 @@ torch::Tensor marlin_gemm_moe(
moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
thread_n, sms, max_par, replicate_input, apply_weights);
return c;
}
}
2 changes: 1 addition & 1 deletion csrc/moe/marlin_moe_ops.h
@@ -9,4 +9,4 @@ torch::Tensor marlin_gemm_moe(
const torch::Tensor& g_idx, const torch::Tensor& perm,
torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size,
bool replicate_input, bool apply_weights);
bool replicate_input, bool apply_weights);
1 change: 0 additions & 1 deletion csrc/moe/torch_bindings.cpp
@@ -16,7 +16,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int "
"size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, "
"bool replicate_input, bool apply_weights) -> Tensor");

m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
#endif
}
221 changes: 217 additions & 4 deletions tests/kernels/test_moe.py
@@ -2,14 +2,22 @@

Run `pytest tests/kernels/test_moe.py`.
"""
from typing import List

import pytest
import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize)
from vllm.model_executor.models.mixtral import MixtralMoE
from vllm.scalar_type import scalar_types


def torch_moe(a, w1, w2, score, topk):
@@ -29,6 +37,20 @@ def torch_moe(a, w1, w2, score, topk):
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)


def torch_moe_single(a, w, score, topk):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
_, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.view(-1)
for i in range(w.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = a[mask] @ w[i].transpose(0, 1)
return (out.view(B, -1, w.shape[1])).sum(dim=1)


@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024])
@@ -43,11 +65,11 @@ def test_fused_moe(
topk: int,
dtype: torch.dtype,
):
a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

score = torch.randn((m, e), device='cuda', dtype=dtype)
score = torch.randn((m, e), device="cuda", dtype=dtype)
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
torch_output = torch_moe(a, w1, w2, score, topk)
torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0)
@@ -99,3 +121,194 @@ def test_mixtral_moe(dtype: torch.dtype):
vllm_states,
rtol=mixtral_moe_tol[dtype],
atol=mixtral_moe_tol[dtype])


def stack_and_dev(tensors: List[torch.Tensor]):
dev = tensors[0].device
return torch.stack(tensors, dim=0).to(dev)


def compute_max_diff(output, output_ref):
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
torch.abs(output_ref))


@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
def test_fused_marlin_moe(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
act_order: bool,
):
torch.manual_seed(7)

if topk > e:
return

# Filter act_order
if act_order:
if group_size == -1:
return
if group_size in (k, n):
return

quant_type = scalar_types.uint4b8
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
for i in range(w2.shape[0]):
w2[i] = torch.eye(k, n, device="cuda", dtype=dtype)

w_ref1_l = []
qweight1_l = []
scales1_l = []
g_idx1_l = []
sort_indices1_l = []

for i in range(w1.shape[0]):
test_perm = torch.randperm(k)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size, act_order,
test_perm)
w_ref1_l.append(w_ref1)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
sort_indices1_l.append(sort_indices1)

w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
g_idx1 = stack_and_dev(g_idx1_l)
sort_indices1 = stack_and_dev(sort_indices1_l)

w_ref2_l = []
qweight2_l = []
scales2_l = []
g_idx2_l = []
sort_indices2_l = []

for i in range(w2.shape[0]):
test_perm = torch.randperm(n)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size, act_order,
test_perm)
w_ref2_l.append(w_ref2)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
sort_indices2_l.append(sort_indices2)

w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
g_idx2 = stack_and_dev(g_idx2_l)
sort_indices2 = stack_and_dev(sort_indices2_l)

score = torch.randn((m, e), device="cuda", dtype=dtype)

topk_weights, topk_ids = fused_topk(a, score, topk, False)

triton_output = fused_moe(
a,
w_ref1.transpose(1, 2).contiguous(),
w_ref2.transpose(1, 2).contiguous(),
score,
topk,
renormalize=False,
)
marlin_output = fused_marlin_moe(
a,
qweight1,
qweight2,
score,
g_idx1,
g_idx2,
sort_indices1,
sort_indices2,
topk_weights,
topk_ids,
w1_scale=scales1,
w2_scale=scales2,
)

assert compute_max_diff(marlin_output, triton_output) < 4e-2


@pytest.mark.skip("This test is here for the sake of debugging, "
"don't run it in automated tests.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [4, 8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
def test_marlin_moe_mmm(
m: int,
n: int,
k: int,
e: int,
topk: int,
group_size: int,
act_order: bool,
):
if topk > e:
return

# Filter act_order
if act_order:
if group_size == -1:
return
if group_size == k:
return

quant_type = scalar_types.uint4b8
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10

w_ref_l = []
qweights_l = []
scales_l = []
g_idx_l = []
sort_indices_l = []

for i in range(w.shape[0]):
test_perm = torch.randperm(k)
w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm)
w_ref_l.append(w_ref)
qweights_l.append(qweight)
scales_l.append(scales)
g_idx_l.append(g_idx)
sort_indices_l.append(sort_indices)

w_ref = stack_and_dev(w_ref_l)
qweight = stack_and_dev(qweights_l).contiguous()
scales = stack_and_dev(scales_l)
g_idx = stack_and_dev(g_idx_l)
sort_indices = stack_and_dev(sort_indices_l)

score = torch.randn((m, e), device="cuda", dtype=dtype)
marlin_output = single_marlin_moe(a,
qweight,
scales,
score,
g_idx,
sort_indices,
topk,
renormalize=False)
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

assert compute_max_diff(marlin_output, torch_output) < 1e-2
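For readers skimming the test diff, here is a condensed sketch of the single_marlin_moe path exercised by the debug-only test above. It assumes a vLLM build with CUDA running on a GPU; the shapes and group_size pick a single point from the test's parameter grid.

import torch

from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
    single_marlin_moe)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    marlin_quantize)
from vllm.scalar_type import scalar_types

m, n, k, e, topk, group_size = 64, 128, 128, 8, 2, -1
dtype = torch.float16

a = torch.randn((m, k), device="cuda", dtype=dtype) / 10     # activations
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10  # one weight matrix per expert
score = torch.randn((m, e), device="cuda", dtype=dtype)      # router logits

# Quantize each expert's weight to 4-bit Marlin format, as the test does.
qweight_l, scales_l, g_idx_l, perm_l = [], [], [], []
for i in range(e):
    _, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
        w[i].transpose(1, 0), scalar_types.uint4b8, group_size,
        False, torch.randperm(k))  # act_order=False, random test permutation
    qweight_l.append(qweight)
    scales_l.append(scales)
    g_idx_l.append(g_idx)
    perm_l.append(sort_indices)

out = single_marlin_moe(a,
                        torch.stack(qweight_l).contiguous(),
                        torch.stack(scales_l),
                        score,
                        torch.stack(g_idx_l),
                        torch.stack(perm_l),
                        topk,
                        renormalize=False)
print(out.shape)  # (m, n): each token is a weighted sum over its top-k experts

The fused two-matmul variant, fused_marlin_moe, follows the same pattern with two quantized weight stacks plus the topk_weights/topk_ids produced by fused_topk, as in test_fused_marlin_moe above.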
3 changes: 3 additions & 0 deletions tests/weight_loading/models-large.txt
@@ -0,0 +1,3 @@
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
2 changes: 0 additions & 2 deletions tests/weight_loading/models.txt
@@ -19,8 +19,6 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
14 changes: 10 additions & 4 deletions vllm/model_executor/layers/fused_moe/__init__.py
@@ -2,16 +2,22 @@
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.triton_utils import HAS_TRITON

__all__ = ["FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported"]
__all__ = [
"FusedMoE",
"FusedMoEMethodBase",
"FusedMoeWeightScaleSupported",
]

if HAS_TRITON:

from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_marlin_moe, fused_moe, fused_topk,
get_config_file_name, grouped_topk)
fused_experts, fused_moe, fused_topk, get_config_file_name,
grouped_topk)

__all__ += [
"fused_marlin_moe",
"single_marlin_moe",
"fused_moe",
"fused_topk",
"fused_experts",
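A short sketch (assuming a vLLM install where HAS_TRITON is true) of the public surface after this change: the Marlin MoE kernels are re-exported from the fused_moe package next to the existing Triton entry points, so callers import them from one place.

from vllm.model_executor.layers.fused_moe import (fused_marlin_moe, fused_moe,
                                                  single_marlin_moe)

# They resolve to the fused_marlin_moe module split out by this PR.
print(fused_marlin_moe.__module__)  # vllm.model_executor.layers.fused_moe.fused_marlin_moe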