Commit 56cd52f

avshalomman authored and prashantgupta24 committed
[Bugfix] adding chunking mechanism to fused_moe to handle large inputs (vllm-project#6029)
1 parent 58d99ec commit 56cd52f

File tree

3 files changed: +74 -48 lines changed

- tests/kernels/test_moe.py
- vllm/envs.py
- vllm/model_executor/layers/fused_moe/fused_moe.py

tests/kernels/test_moe.py

+1 -1

@@ -29,7 +29,7 @@ def torch_moe(a, w1, w2, score, topk):
             topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
 
 
-@pytest.mark.parametrize("m", [512, 222, 33, 1])
+@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
 @pytest.mark.parametrize("n", [2048, 256, 1024])
 @pytest.mark.parametrize("k", [128, 511, 1024])
 @pytest.mark.parametrize("e", [8, 64])
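
For orientation, the new m = 1024 * 128 case is 131072 tokens, which exceeds the 65536-token cap the old code rejected; under the chunked implementation below it is processed as two full chunks at the default chunk size. A rough sketch of the chunk arithmetic (illustrative only, not part of the diff; CHUNK_SIZE mirrors the default added in vllm/envs.py below):

```python
# Illustrative only: how m = 1024 * 128 tokens splits into chunks under the
# default VLLM_FUSED_MOE_CHUNK_SIZE of 64 * 1024, following the loop bounds
# used in fused_experts in the diff below.
CHUNK_SIZE = 64 * 1024
num_tokens = 1024 * 128

chunks = []
for chunk in range((num_tokens // CHUNK_SIZE) + 1):
    begin = chunk * CHUNK_SIZE
    end = min((chunk + 1) * CHUNK_SIZE, num_tokens)
    if end == begin:  # empty trailing chunk; the real loop breaks here
        break
    chunks.append((begin, end))

print(chunks)  # [(0, 65536), (65536, 131072)]
```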

vllm/envs.py

+3 -0

@@ -32,6 +32,7 @@
     VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None
     VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
     VLLM_XLA_CACHE_PATH: str = "~/.vllm/xla_cache/"
+    VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
     VLLM_USE_RAY_COMPILED_DAG: bool = False
     VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
     VLLM_IMAGE_FETCH_TIMEOUT: int = 5
@@ -248,6 +249,8 @@
     # Only used for XLA devices such as TPUs.
     "VLLM_XLA_CACHE_PATH":
     lambda: os.getenv("VLLM_XLA_CACHE_PATH", "~/.vllm/xla_cache/"),
+    "VLLM_FUSED_MOE_CHUNK_SIZE":
+    lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "65536")),
 }
 
 # end-env-vars-definition
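
As a usage sketch (not from the diff), the chunk size can presumably be tuned through this environment variable; the "65536" default string matches the 64 * 1024 annotation above. The snippet assumes the variable is set before the value is read and that vllm.envs resolves it lazily via os.getenv, as the lambda above suggests:

```python
# Hypothetical override of the new knob; values and behavior are assumptions
# based only on the lambda added above, not a documented API.
import os

os.environ["VLLM_FUSED_MOE_CHUNK_SIZE"] = str(32 * 1024)

import vllm.envs as envs

print(envs.VLLM_FUSED_MOE_CHUNK_SIZE)  # expected: 32768
```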

vllm/model_executor/layers/fused_moe/fused_moe.py

+70 -47

@@ -8,6 +8,7 @@
 import triton
 import triton.language as tl
 
+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 
@@ -420,13 +421,12 @@ def fused_experts(hidden_states: torch.Tensor,
         torch.float32, torch.float16, torch.bfloat16
     ]
 
-    M, _ = hidden_states.shape
+    num_tokens, _ = hidden_states.shape
     E, N, _ = w1.shape
-
-    if M > 65536:
-        # https://github.com/vllm-project/vllm/issues/5938
-        raise ValueError("MoE kernel does not support more than 65536 tokens, "
-                         f"but got {M}")
+    # We execute the fused_moe kernel in chunks to circumvent this issue:
+    # https://github.com/vllm-project/vllm/issues/5938
+    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
+    M = min(num_tokens, CHUNK_SIZE)
 
     if override_config:
         config = override_config
@@ -455,51 +455,74 @@ def fused_experts(hidden_states: torch.Tensor,
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)
 
-    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, config['BLOCK_SIZE_M'], E)
     compute_type = (tl.bfloat16
                     if hidden_states.dtype == torch.bfloat16 else tl.float16)
 
-    invoke_fused_moe_kernel(hidden_states,
-                            w1,
-                            intermediate_cache1,
-                            a1_scale,
-                            w1_scale,
-                            topk_weights,
-                            topk_ids,
-                            sorted_token_ids,
-                            expert_ids,
-                            num_tokens_post_padded,
-                            False,
-                            topk_ids.shape[1],
-                            config,
-                            compute_type=compute_type,
-                            use_fp8=use_fp8)
-
-    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
-
-    invoke_fused_moe_kernel(intermediate_cache2,
-                            w2,
-                            intermediate_cache3,
-                            a2_scale,
-                            w2_scale,
-                            topk_weights,
-                            topk_ids,
-                            sorted_token_ids,
-                            expert_ids,
-                            num_tokens_post_padded,
-                            True,
-                            1,
-                            config,
-                            compute_type=compute_type,
-                            use_fp8=use_fp8)
-
     if inplace:
-        return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                         dim=1,
-                         out=hidden_states)
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                     dim=1)
+        out_hidden_states = hidden_states
+    else:
+        out_hidden_states = torch.empty_like(hidden_states)
+
+    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
+        begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
+                                          min((chunk + 1) * CHUNK_SIZE,
+                                              num_tokens))
+        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
+        tokens_in_chunk, _ = curr_hidden_states.shape
+
+        if tokens_in_chunk == 0:
+            break
+
+        if tokens_in_chunk < CHUNK_SIZE:
+            # will only happen in the last chunk
+            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
+            intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
+            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
+
+        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
+        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
+
+        sorted_token_ids, expert_ids, num_tokens_post_padded = (
+            moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E))
+
+        invoke_fused_moe_kernel(curr_hidden_states,
+                                w1,
+                                intermediate_cache1,
+                                a1_scale,
+                                w1_scale,
+                                curr_topk_weights,
+                                curr_topk_ids,
+                                sorted_token_ids,
+                                expert_ids,
+                                num_tokens_post_padded,
+                                False,
+                                topk_ids.shape[1],
+                                config,
+                                compute_type=compute_type,
+                                use_fp8=use_fp8)
+
+        ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+
+        invoke_fused_moe_kernel(intermediate_cache2,
+                                w2,
+                                intermediate_cache3,
+                                a2_scale,
+                                w2_scale,
+                                curr_topk_weights,
+                                curr_topk_ids,
+                                sorted_token_ids,
+                                expert_ids,
+                                num_tokens_post_padded,
+                                True,
+                                1,
+                                config,
+                                compute_type=compute_type,
+                                use_fp8=use_fp8)
+
+        torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
+                  dim=1,
+                  out=out_hidden_states[begin_chunk_idx:end_chunk_idx])
+    return out_hidden_states
 
 
 def fused_moe(
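
Stripped of the MoE-specific plumbing, the fix slices the token dimension into fixed-size chunks and writes each partial result into a preallocated output buffer. A minimal, self-contained sketch of that pattern (hypothetical names; expensive_kernel stands in for the two invoke_fused_moe_kernel calls plus silu_and_mul, and is not vLLM code):

```python
import torch

CHUNK_SIZE = 64 * 1024  # mirrors the VLLM_FUSED_MOE_CHUNK_SIZE default


def expensive_kernel(x: torch.Tensor) -> torch.Tensor:
    # Placeholder for the per-chunk MoE computation.
    return x * 2.0


def chunked_apply(hidden_states: torch.Tensor,
                  inplace: bool = False) -> torch.Tensor:
    num_tokens, _ = hidden_states.shape
    # Reuse the input as the output buffer when inplace, as fused_experts does.
    out = hidden_states if inplace else torch.empty_like(hidden_states)

    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
        begin = chunk * CHUNK_SIZE
        end = min((chunk + 1) * CHUNK_SIZE, num_tokens)
        curr = hidden_states[begin:end]
        if curr.shape[0] == 0:
            # Happens when num_tokens is an exact multiple of CHUNK_SIZE.
            break
        out[begin:end] = expensive_kernel(curr)
    return out


x = torch.randn(150_000, 16)
y = chunked_apply(x)
assert torch.allclose(y, x * 2.0)
```

Processing at most CHUNK_SIZE tokens per iteration bounds the size of each kernel launch and of the intermediate caches, which is how the patch sidesteps the failure reported in vllm-project/vllm#5938 for inputs above 65536 tokens.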
