 import triton
 import triton.language as tl

+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger

@@ -420,13 +421,12 @@ def fused_experts(hidden_states: torch.Tensor,
         torch.float32, torch.float16, torch.bfloat16
     ]

-    M, _ = hidden_states.shape
+    num_tokens, _ = hidden_states.shape
     E, N, _ = w1.shape
-
-    if M > 65536:
-        # https://github.com/vllm-project/vllm/issues/5938
-        raise ValueError("MoE kernel does not support more than 65536 tokens, "
-                         f"but got {M}")
+    # We execute the fused_moe kernel in chunks to circumvent this issue:
+    # https://github.com/vllm-project/vllm/issues/5938
+    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
+    M = min(num_tokens, CHUNK_SIZE)

     if override_config:
         config = override_config
@@ -455,51 +455,74 @@ def fused_experts(hidden_states: torch.Tensor,
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)

-    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, config['BLOCK_SIZE_M'], E)
     compute_type = (tl.bfloat16
                     if hidden_states.dtype == torch.bfloat16 else tl.float16)

-    invoke_fused_moe_kernel(hidden_states,
-                            w1,
-                            intermediate_cache1,
-                            a1_scale,
-                            w1_scale,
-                            topk_weights,
-                            topk_ids,
-                            sorted_token_ids,
-                            expert_ids,
-                            num_tokens_post_padded,
-                            False,
-                            topk_ids.shape[1],
-                            config,
-                            compute_type=compute_type,
-                            use_fp8=use_fp8)
-
-    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
-
-    invoke_fused_moe_kernel(intermediate_cache2,
-                            w2,
-                            intermediate_cache3,
-                            a2_scale,
-                            w2_scale,
-                            topk_weights,
-                            topk_ids,
-                            sorted_token_ids,
-                            expert_ids,
-                            num_tokens_post_padded,
-                            True,
-                            1,
-                            config,
-                            compute_type=compute_type,
-                            use_fp8=use_fp8)
-
     if inplace:
-        return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                         dim=1,
-                         out=hidden_states)
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                     dim=1)
+        out_hidden_states = hidden_states
+    else:
+        out_hidden_states = torch.empty_like(hidden_states)
+
+    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
+        begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
+                                          min((chunk + 1) * CHUNK_SIZE,
+                                              num_tokens))
+        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
+        tokens_in_chunk, _ = curr_hidden_states.shape
+
+        if tokens_in_chunk == 0:
+            break
+
+        if tokens_in_chunk < CHUNK_SIZE:
+            # will only happen in the last chunk
+            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
+            intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
+            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
+
+        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
+        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
+
+        sorted_token_ids, expert_ids, num_tokens_post_padded = (
+            moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E))
+
+        invoke_fused_moe_kernel(curr_hidden_states,
+                                w1,
+                                intermediate_cache1,
+                                a1_scale,
+                                w1_scale,
+                                curr_topk_weights,
+                                curr_topk_ids,
+                                sorted_token_ids,
+                                expert_ids,
+                                num_tokens_post_padded,
+                                False,
+                                topk_ids.shape[1],
+                                config,
+                                compute_type=compute_type,
+                                use_fp8=use_fp8)
+
+        ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+
+        invoke_fused_moe_kernel(intermediate_cache2,
+                                w2,
+                                intermediate_cache3,
+                                a2_scale,
+                                w2_scale,
+                                curr_topk_weights,
+                                curr_topk_ids,
+                                sorted_token_ids,
+                                expert_ids,
+                                num_tokens_post_padded,
+                                True,
+                                1,
+                                config,
+                                compute_type=compute_type,
+                                use_fp8=use_fp8)
+
+        torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
+                  dim=1,
+                  out=out_hidden_states[begin_chunk_idx:end_chunk_idx])
+    return out_hidden_states


 def fused_moe(
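
For readers skimming the diff, the following is a minimal, self-contained sketch (not vLLM code) of the chunked-execution pattern introduced above. The function name chunked_apply, the toy CHUNK_SIZE value, and the "* 2.0" body are hypothetical stand-ins for the two invoke_fused_moe_kernel launches, the silu_and_mul activation, and the final torch.sum reduction; only the chunk arithmetic, the inplace handling, and the early break on an empty final chunk mirror the diff.

# Hedged sketch, not vLLM code: mirrors the chunking logic of the diff with a
# toy per-chunk operation standing in for the MoE kernels and reduction.
import torch

CHUNK_SIZE = 4  # toy value; the diff reads envs.VLLM_FUSED_MOE_CHUNK_SIZE


def chunked_apply(hidden_states: torch.Tensor,
                  inplace: bool = False) -> torch.Tensor:
    num_tokens, _ = hidden_states.shape
    out = hidden_states if inplace else torch.empty_like(hidden_states)

    # (num_tokens // CHUNK_SIZE) + 1 iterations always cover every token;
    # when num_tokens is an exact multiple of CHUNK_SIZE, the extra iteration
    # sees an empty slice and breaks.
    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
        begin, end = chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE,
                                             num_tokens)
        curr = hidden_states[begin:end]
        if curr.shape[0] == 0:
            break
        # Per-chunk work: placeholder for the kernel launches + torch.sum.
        out[begin:end] = curr * 2.0
    return out


x = torch.randn(10, 8)
assert torch.allclose(chunked_apply(x), x * 2.0)
assert torch.allclose(chunked_apply(x.clone(), inplace=True), x * 2.0)

Running the sketch with token counts that are and are not multiples of CHUNK_SIZE confirms that each token is processed exactly once, which is the property the loop bound and the empty-chunk break in the diff rely on.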