
Commit

fix bug
sleepcoo committed Feb 16, 2025
1 parent d802719 commit 2aa081b
Showing 2 changed files with 9 additions and 2 deletions.
6 changes: 5 additions & 1 deletion python/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -273,7 +273,11 @@ def grouped_gemm_triton_kernel(
     if group_k > 0 and group_n > 0:
         a_scale_ptrs = scale_a + (m_range_start + offs_am[:, None]) * as_stride_0
         offs_bsn = offs_bn // group_n
-        b_scale_ptrs = scale_b + (expert_id * bs_stride_0) + offs_bsn * bs_stride_1
+        b_scale_ptrs = (
+            scale_b
+            + (expert_id * bs_stride_0)
+            + (n_range_start + offs_bsn) * bs_stride_1
+        )
 
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
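The kernel change is easier to see outside Triton. The snippet below is a minimal NumPy sketch of block-wise weight-scale lookup, assuming one scale per group_n output columns; the names (group_n, scale_b, n_start) mirror the kernel, but the shapes and values are hypothetical and this is not the kernel itself. The point it illustrates: a GEMM tile that starts at a nonzero offset along N must fold that offset into its scale index, which is what adding n_range_start to the scale-pointer arithmetic does; without it, every tile would read the scales of the first column block.

import numpy as np

group_n = 128                                # one scale per 128 output columns (assumed)
scale_b = np.array([0.5, 1.0, 2.0, 4.0])     # per-block scales along N, shape (N // group_n,)

def scales_for_tile(n_start, tile_n):
    # Scale for every column of a tile whose first global column is n_start.
    cols = n_start + np.arange(tile_n)       # global column indices
    return scale_b[cols // group_n]          # block index along N includes the tile offset

print(scales_for_tile(0, 128))               # all 0.5 (column block 0)
print(scales_for_tile(256, 128))             # all 2.0 (column block 2), not 0.5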
5 changes: 4 additions & 1 deletion python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -89,7 +89,7 @@ def forward(
             use_fp8_w8a8,
             scale_a,
             scale_b,
-            block_shape=self.quant_method.block_quant,
+            block_shape=self.quant_method.quant_config.weight_block_size,
         )
         return c
 
@@ -435,6 +435,9 @@ def _load_fp8_scale(
 
 
 class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
+    def __init__(self):
+        self.block_quant = False
+
     def create_weights(
         self,
         layer: torch.nn.Module,
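On the layer.py side, the grouped GEMM is now handed the actual quantization block size (quant_config.weight_block_size) rather than the boolean block_quant flag, and UnquantizedEPMoEMethod now defines block_quant = False, presumably so code that inspects that attribute keeps working for the unquantized path. A rough sketch of the distinction, using hypothetical stub classes and values rather than the real sglang objects:

class QuantConfigStub:
    # Block-wise FP8: one scale per [block_n, block_k] tile of the weight (assumed 128x128).
    weight_block_size = [128, 128]

class Fp8EPMoEMethodStub:
    block_quant = True                    # flag: block-wise quantization is in use
    quant_config = QuantConfigStub()

class UnquantizedEPMoEMethodStub:
    def __init__(self):
        self.block_quant = False          # attribute now exists, so attribute checks succeed

quant_method = Fp8EPMoEMethodStub()
# The GEMM needs the block dimensions themselves, not the True/False flag:
block_shape = quant_method.quant_config.weight_block_size
print(block_shape)                                  # [128, 128]
print(UnquantizedEPMoEMethodStub().block_quant)     # False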
