This commit uses the DRL predictor to predict future expert selection.
Prefill stage is not yet complete
gnpinkert committed Sep 28, 2024
1 parent 7b33075 commit 07787f3
Showing 4 changed files with 158 additions and 72 deletions.
6 changes: 5 additions & 1 deletion vllm/model_executor/layers/fused_moe/__init__.py
@@ -1,10 +1,14 @@
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                         FusedMoEMethodBase,
                                                         MoeGpuBuffer,
                                                         DebugCudaEvent)
from vllm.triton_utils import HAS_TRITON

__all__ = [
"FusedMoE",
"FusedMoEMethodBase",
"MoeGpuBuffer",
"DebugCudaEvent",
]

if HAS_TRITON:
26 changes: 16 additions & 10 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -12,6 +12,7 @@
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.model_executor.layers.fused_moe import DebugCudaEvent, MoeGpuBuffer

logger = init_logger(__name__)

@@ -456,11 +457,10 @@ def get_config_dtype_str(dtype: torch.dtype,


def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
moe_gpu_buffer: MoeGpuBuffer,
moe_events: DebugCudaEvent,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
stream: torch.cuda.Stream,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
use_fp8_w8a8: bool = False,
@@ -470,8 +470,9 @@ def fused_experts(hidden_states: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None):

stream.synchronize()
# Check constraints.
w1 = moe_gpu_buffer.w13_gpu
w2 = moe_gpu_buffer.w2_gpu
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
#assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
@@ -545,8 +546,11 @@ def fused_experts(hidden_states: torch.Tensor,

sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E))


"""
print(f"Top_k IDs: {curr_topk_ids}")
print(f"sorted_token_ids IDs: {sorted_token_ids}")
print(f"expert ids: {expert_ids}")
"""
invoke_fused_moe_kernel(curr_hidden_states,
w1,
intermediate_cache1,
@@ -587,9 +591,10 @@ def fused_experts(hidden_states: torch.Tensor,
dim=1,
out=out_hidden_states[begin_chunk_idx:end_chunk_idx])

del w1
del w2
torch.cuda.empty_cache()
# We could record these events right after the kernels finish, but it is unclear
# how many iterations the chunked loop above will run.
moe_events.mlp_w2_finished_event.record()

return out_hidden_states
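
For context, here is a CPU-only toy of the computation the kernel above performs over the buffer-resident experts. It is a sketch, not part of the diff: the toy sizes, the SwiGLU-style activation order, and the name toy_fused_experts are illustrative assumptions, and it ignores chunking, quantization, and the CUDA events.

import torch

top_k, hidden, inter = 2, 16, 32                  # assumed toy sizes
tokens = torch.randn(4, hidden)

# Stand-ins for MoeGpuBuffer.w13_gpu / .w2_gpu: only the resident experts live here.
w13 = torch.randn(top_k, 2 * inter, hidden)       # fused gate+up projection per slot
w2 = torch.randn(top_k, hidden, inter)            # down projection per slot

def toy_fused_experts(x, w13, w2, topk_ids, topk_weights):
    """Per-token expert MLP over buffer-slot indices (0 .. top_k - 1)."""
    out = torch.zeros_like(x)
    for t in range(x.shape[0]):
        for slot, weight in zip(topk_ids[t], topk_weights[t]):
            gate, up = (w13[slot] @ x[t]).chunk(2)
            act = torch.nn.functional.silu(gate) * up     # SwiGLU-style activation
            out[t] += weight * (w2[slot] @ act)           # weighted down projection
    return out

topk_ids = torch.randint(0, top_k, (4, top_k))            # already remapped to buffer slots
topk_weights = torch.softmax(torch.randn(4, top_k), dim=-1)
print(toy_fused_experts(tokens, w13, w2, topk_ids, topk_weights).shape)   # torch.Size([4, 16])

The mlp_w2_finished_event recorded at the end of the real function presumably gives other streams a safe point to wait on before the buffer contents are replaced.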


@@ -660,7 +665,8 @@ def fused_moe(
unique_indices = torch.unique(topk_ids.flatten())

# This number needs to match the number of experts in the layer
if unique_indices.shape[0] != 64:
# 64 for DeepSeek, 8 for Mixtral
if unique_indices.shape[0] != 8:
sorted_elements, _ = torch.sort(unique_indices)
rank_mapping = {elem.item(): rank for rank, elem in enumerate(sorted_elements)}
for old_value, new_value in rank_mapping.items():
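
The tail of fused_moe above compresses the global expert ids that were actually routed to into dense ranks 0 .. N-1 so they can index a compacted weight tensor. A self-contained illustration with made-up ids (the expert count of 8 follows the Mixtral comment above):

import torch

topk_ids = torch.tensor([[5, 12], [12, 40], [5, 40]])      # made-up global expert ids

unique_indices = torch.unique(topk_ids.flatten())           # tensor([ 5, 12, 40])
if unique_indices.shape[0] != 8:                            # fewer experts hit than the layer holds
    sorted_elements, _ = torch.sort(unique_indices)
    rank_mapping = {elem.item(): rank for rank, elem in enumerate(sorted_elements)}
    for old_value, new_value in rank_mapping.items():
        topk_ids[topk_ids == old_value] = new_value         # rewrite ids in place, smallest first

print(topk_ids)   # tensor([[0, 1], [1, 2], [0, 2]]) -- dense indices into the loaded subset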
106 changes: 72 additions & 34 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -11,10 +11,37 @@
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs

import numpy as np
logger = init_logger(__name__)


class DebugCudaEvent:
def __init__(self, topk: int):
self._topk_decided_event = torch.cuda.Event()
self.mlp_w13_finished_event = torch.cuda.Event()
self.mlp_w2_finished_event = torch.cuda.Event()
self.experts = np.full((topk,), fill_value=0, dtype=np.int64)
self.is_first_layer = True

def reset_events(self):
self._topk_decided_event = torch.cuda.Event()
self.mlp_w13_finished_event = torch.cuda.Event()
self.mlp_w2_finished_event = torch.cuda.Event()

def triggerTopkEvent(self, experts: torch.Tensor):
self.experts = experts[0].cpu().numpy()
self._topk_decided_event.record()
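
A minimal usage sketch of DebugCudaEvent as defined above. It assumes a CUDA device and this branch of vLLM; the expert ids are made up.

import torch
from vllm.model_executor.layers.fused_moe import DebugCudaEvent

events = DebugCudaEvent(topk=2)

# The router produces per-token expert ids; triggerTopkEvent copies the first
# token's choices to the host and records the "top-k decided" event on the
# current stream.
topk_ids = torch.tensor([[3, 7], [1, 3]], device="cuda")
events.triggerTopkEvent(topk_ids)

torch.cuda.synchronize()      # in the real flow a consumer stream would wait on the event instead
print(events.experts)         # array([3, 7])
events.reset_events()         # fresh events for the next forward pass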


class MoeGpuBuffer:
def __init__(self, w13_shape: tuple[int, int, int], w2_shape: tuple[int, int, int]):
assert w13_shape[0] == w2_shape[0], "Moe GPU buffers must have the same number of experts"
self.w13_gpu = torch.nn.Parameter(torch.zeros(w13_shape), requires_grad=False)
self.w2_gpu = torch.nn.Parameter(torch.zeros(w2_shape), requires_grad=False)
self.expert_ids: List[int] = []
self.load_predicted_experts_stream = torch.cuda.Stream()
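
Similarly, a sketch of allocating a MoeGpuBuffer sized for the resident experts. It assumes a CUDA device (the class creates a dedicated copy stream), and the toy dimensions are assumptions; the shape layout mirrors what fused_experts appears to expect (w13 fuses the gate and up projections, w2 is the down projection).

import torch
from vllm.model_executor.layers.fused_moe import MoeGpuBuffer

top_k, hidden_size, intermediate_size = 2, 64, 128        # toy sizes only

buf = MoeGpuBuffer(
    w13_shape=(top_k, 2 * intermediate_size, hidden_size),
    w2_shape=(top_k, hidden_size, intermediate_size),
)

print(buf.w13_gpu.shape, buf.w2_gpu.shape)   # one weight slot per resident expert
print(buf.expert_ids)                        # [] until experts are copied in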


class FusedMoEMethodBase(QuantizeMethodBase):

@abstractmethod
@@ -31,6 +58,13 @@ def apply(self, layer: torch.nn.Module, x: torch.Tensor,
raise NotImplementedError


def find_invalid_indices(predicted_expert_ids: torch.Tensor, actual_expert_ids: torch.Tensor) -> List[int]:
mask = torch.isin(predicted_expert_ids, actual_expert_ids)
invalid_indices = torch.nonzero(~mask, as_tuple=True)[0]

return invalid_indices.tolist()
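
find_invalid_indices returns the positions in the predicted set whose expert was not actually selected, i.e. the buffer slots that are free to be overwritten. A quick made-up example, inlining the same two torch ops:

import torch

predicted = torch.tensor([3, 7, 12, 40])   # experts the predictor loaded into the buffer
actual = torch.tensor([7, 9, 40, 41])      # experts the router actually picked

mask = torch.isin(predicted, actual)                     # [False, True, False, True]
invalid_indices = torch.nonzero(~mask, as_tuple=True)[0]
print(invalid_indices.tolist())                          # [0, 2] -> slots holding experts 3 and 12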


class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""

@@ -64,9 +98,8 @@ def apply(self,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
stream: torch.cuda.Stream,
w13_gpu: torch.nn.Parameter,
w2_gpu: torch.nn.Parameter,
moe_gpu_buffer: MoeGpuBuffer,
router_event: DebugCudaEvent,
w1_cpu: torch.nn.Parameter,
w2_cpu: torch.nn.Parameter,
topk_group: Optional[int] = None,
@@ -76,28 +109,26 @@ def apply(self,
layer=layer,
router_logits=router_logits,
top_k=top_k,
w1=w13_gpu,
w2=w2_gpu,
moe_gpu_buffer=moe_gpu_buffer,
router_event=router_event,
w1_cpu=w1_cpu,
w2_cpu=w2_cpu,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
stream=stream)
num_expert_group=num_expert_group)

def forward_cuda(self,
layer: torch.nn.Module,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
moe_gpu_buffer: MoeGpuBuffer,
router_event: DebugCudaEvent,
w1_cpu: torch.Tensor,
w2_cpu: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
stream: torch.cuda.Stream,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None) -> torch.Tensor:

@@ -113,28 +144,37 @@ def forward_cuda(self,
topk_group=topk_group,
num_expert_group=num_expert_group)

stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
unique_indices = torch.unique(topk_ids.flatten())

if router_event.is_first_layer:
w1_cpu.pin_memory()
w2_cpu.pin_memory()
w1 = w1_cpu[unique_indices.to('cpu')].to('cuda')
w2 = w2_cpu[unique_indices.to('cpu')].to('cuda')

if unique_indices.shape[0] != 64:
sorted_elements, _ = torch.sort(unique_indices)
rank_mapping = {elem.item(): rank for rank, elem in enumerate(sorted_elements)}
for old_value, new_value in rank_mapping.items():
topk_ids[topk_ids == old_value] = new_value
moe_gpu_buffer.w13_gpu[:top_k, :, :] = w1_cpu[topk_ids[0].tolist()].to('cuda', non_blocking=True)
moe_gpu_buffer.w2_gpu[:top_k, :, :] = w2_cpu[topk_ids[0].tolist()].to('cuda', non_blocking=True)
topk_ids[0] = torch.arange(0, top_k)

else:
router_event.triggerTopkEvent(topk_ids)
invalid_indices = find_invalid_indices(torch.tensor(moe_gpu_buffer.expert_ids).to('cuda'), topk_ids[0])
for index in range(len(topk_ids[0])):
actual_id = topk_ids[0][index]
if actual_id not in moe_gpu_buffer.expert_ids:
with torch.cuda.stream(moe_gpu_buffer.load_predicted_experts_stream):
replaced_index = invalid_indices.pop(0) # Pop the first value
w1_cpu.pin_memory()
w2_cpu.pin_memory()

moe_gpu_buffer.w13_gpu[replaced_index, :, :] = w1_cpu[actual_id].to('cuda', non_blocking=True)
moe_gpu_buffer.w2_gpu[replaced_index, :, :] = w2_cpu[actual_id].to('cuda', non_blocking=True)
moe_gpu_buffer.expert_ids[replaced_index] = actual_id
topk_ids[0][index] = replaced_index
else:
new_index = moe_gpu_buffer.expert_ids.index(actual_id)
topk_ids[0][index] = new_index

return fused_experts(hidden_states=x,
w1=w1,
w2=w2,
stream=stream,
moe_gpu_buffer=moe_gpu_buffer,
topk_weights=topk_weights,
topk_ids=topk_ids,
moe_events=router_event,
inplace=True)
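
Pulling the decode-path logic above together: the predictor has preloaded a set of experts into the buffer, the router's actual choices are compared against it, mispredicted slots are overwritten with the missing experts, and topk_ids are rewritten into buffer-slot indices before calling fused_experts. Below is a self-contained CPU sketch of that bookkeeping with toy shapes; it omits the CUDA streams, pinned memory, and non-blocking copies used in the diff.

import torch

top_k, hidden, inter, n_experts = 4, 8, 16, 64         # toy sizes
w13_cpu = torch.randn(n_experts, 2 * inter, hidden)    # full expert weights kept off-GPU
w2_cpu = torch.randn(n_experts, hidden, inter)

# State left by the predictor: which experts it guessed and their weight slots.
buffer_ids = [3, 7, 12, 40]
w13_buf = w13_cpu[buffer_ids].clone()
w2_buf = w2_cpu[buffer_ids].clone()

# What the router actually picked for the first token.
topk_ids = torch.tensor([7, 9, 40, 41])

# Slots whose predicted expert was not selected are free to be replaced.
mask = torch.isin(torch.tensor(buffer_ids), topk_ids)
free_slots = torch.nonzero(~mask, as_tuple=True)[0].tolist()

for i in range(len(topk_ids)):
    actual_id = int(topk_ids[i])
    if actual_id in buffer_ids:
        topk_ids[i] = buffer_ids.index(actual_id)       # hit: reuse the existing slot
    else:
        slot = free_slots.pop(0)                        # miss: evict a mispredicted slot
        w13_buf[slot] = w13_cpu[actual_id]
        w2_buf[slot] = w2_cpu[actual_id]
        buffer_ids[slot] = actual_id
        topk_ids[i] = slot

print(buffer_ids)          # [9, 7, 41, 40] -- experts now resident in the buffer
print(topk_ids.tolist())   # [1, 0, 3, 2]  -- indices into the buffer slots

In the diff itself the misses are copied on moe_gpu_buffer.load_predicted_experts_stream with non_blocking=True, so the host-to-device transfers can overlap with the rest of the forward pass.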

def forward_cpu(self, *args, **kwargs):
@@ -315,28 +355,26 @@ def select_experts(hidden_states: torch.Tensor,

def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor,
w13_gpu: torch.nn.Parameter,
w2_gpu: torch.nn.Parameter,
moe_gpu_buffer: MoeGpuBuffer,
router_event: DebugCudaEvent,
w1_cpu: torch.nn.Parameter,
w2_cpu: torch.nn.Parameter,
stream: torch.cuda.Stream):
w2_cpu: torch.nn.Parameter):
assert self.quant_method is not None

# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
w13_gpu=w13_gpu,
w2_gpu=w2_gpu,
moe_gpu_buffer=moe_gpu_buffer,
router_event=router_event,
w1_cpu=w1_cpu,
w2_cpu=w2_cpu,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
stream=stream)
num_expert_group=self.num_expert_group)

if self.reduce_results and self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(

