From 2aa081b455cb3207b21d59fe28e72c8f846171f1 Mon Sep 17 00:00:00 2001
From: sleepcoo
Date: Sun, 16 Feb 2025 18:45:07 +0800
Subject: [PATCH] fix bug

---
 python/sglang/srt/layers/moe/ep_moe/kernels.py | 6 +++++-
 python/sglang/srt/layers/moe/ep_moe/layer.py   | 5 ++++-
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py
index ff28a167b40..80086f1f74f 100644
--- a/python/sglang/srt/layers/moe/ep_moe/kernels.py
+++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -273,7 +273,11 @@ def grouped_gemm_triton_kernel(
     if group_k > 0 and group_n > 0:
         a_scale_ptrs = scale_a + (m_range_start + offs_am[:, None]) * as_stride_0
         offs_bsn = offs_bn // group_n
-        b_scale_ptrs = scale_b + (expert_id * bs_stride_0) + offs_bsn * bs_stride_1
+        b_scale_ptrs = (
+            scale_b
+            + (expert_id * bs_stride_0)
+            + (n_range_start + offs_bsn) * bs_stride_1
+        )
 
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index 1c740cd8a01..a2433d95ede 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -89,7 +89,7 @@ def forward(
                 use_fp8_w8a8,
                 scale_a,
                 scale_b,
-                block_shape=self.quant_method.block_quant,
+                block_shape=self.quant_method.quant_config.weight_block_size,
             )
         return c
 
@@ -435,6 +435,9 @@ def _load_fp8_scale(
 
 
 class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
+    def __init__(self):
+        self.block_quant = False
+
     def create_weights(
         self,
         layer: torch.nn.Module,
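
Reviewer note (not part of the patch): below is a minimal plain-Python sketch of the indexing principle behind the kernels.py change, for readers unfamiliar with block-wise FP8 scales. When weight scales are stored per (expert, column-block), the scale lookup for a GEMM tile must include the tile's starting position along N, mirroring how m_range_start is already added for the activation scales. All names, shapes, and the assumption that n_range_start is expressed in units of scale blocks are illustrative only; the real logic lives in the Triton kernel grouped_gemm_triton_kernel above.

import numpy as np

group_n = 2                                   # columns covered by one weight scale
num_blocks = 4                                # N // group_n scale blocks per expert
scale_b = np.arange(1.0, num_blocks + 1)      # per-block scales: [1, 2, 3, 4]

def tile_scales(n_range_start, offs_bn, with_fix=True):
    """Scales a tile would read, given local column offsets offs_bn."""
    offs_bsn = offs_bn // group_n             # local block offsets within the tile
    if with_fix:
        # fixed behavior: offset by where the tile starts along N
        return scale_b[n_range_start + offs_bsn]
    # buggy behavior: every tile reads from block 0 onward
    return scale_b[offs_bsn]

offs_bn = np.array([0, 1])                    # a 2-column tile
print(tile_scales(2, offs_bn, with_fix=True))   # [3. 3.] -> scales of block 2
print(tile_scales(2, offs_bn, with_fix=False))  # [1. 1.] -> wrongly reuses block 0

Without the offset, any tile past the first column range silently dequantizes with the wrong scales, which is the symptom this patch addresses; the layer.py changes make the same code path receive the actual weight_block_size instead of a boolean flag.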