From 5923ab9524e32006ffb9354c5340b6988a45fe3e Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Thu, 10 Jul 2025 19:39:18 -0700 Subject: [PATCH] [fix]: disable cutlass block scaled group gemm for EP (#20781) Signed-off-by: Duncan Moss --- .../moe/blockwise_scaled_group_mm_sm100.cu | 9 +++--- .../layers/fused_moe/cutlass_moe.py | 29 +++++++++++++++++-- .../layers/fused_moe/fused_moe.py | 5 ++-- 3 files changed, 34 insertions(+), 9 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu b/csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu index 236d76ed52081..6c8f6309ef43f 100644 --- a/csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu +++ b/csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu @@ -201,11 +201,10 @@ void run_blockwise_scaled_group_mm( reinterpret_cast( layout_sfb.data_ptr())}; - cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = a_ptrs.get_device(); - hw_info.sm_count = - cutlass::KernelHardwareInfo::query_device_multiprocessor_count( - hw_info.device_id); + int device_id = a_ptrs.device().index(); + static const cutlass::KernelHardwareInfo hw_info{ + device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + device_id)}; // Epilogue Arguments typename GemmKernel::EpilogueArguments epilogue_args{ diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 623003f65adaa..d6a30e3426950 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -553,8 +553,10 @@ def cutlass_moe_fp4(a: torch.Tensor, return out.to(dtype=out_dtype) -def _valid_cutlass_block_scaled_grouped_gemm(w1: torch.Tensor, - w2: torch.Tensor) -> bool: +def _valid_cutlass_block_scaled_grouped_gemm( + w1: torch.Tensor, w2: torch.Tensor, inplace: bool, activation: str, + apply_router_weight_on_input: bool, + expert_map: Optional[torch.Tensor]) -> bool: def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int): return N % 128 == 0 and K % 128 == 0 @@ -570,6 +572,29 @@ def _valid_cutlass_block_scaled_grouped_gemm(w1: torch.Tensor, "CutlassBlockScaledGroupedGemm disabled: invalid weight dtype(s).") return False + if expert_map is not None: + logger.debug( + "CutlassBlockScaledGroupedGemm disabled: expert_parallel is" + " not supported.") + return False + + if activation != "silu": + logger.debug( + "CutlassBlockScaledGroupedGemm disabled: only activation silu is" + " supported.") + return False + + if apply_router_weight_on_input: + logger.debug("CutlassBlockScaledGroupedGemm disabled:" + " apply_router_weight_on_input is not supported.") + return False + + if inplace: + logger.debug( + "CutlassBlockScaledGroupedGemm disabled: inplace is not supported." + ) + return False + return True diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1947a3d5fac11..e16cc9e8507d5 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1192,8 +1192,9 @@ def fused_experts( apply_router_weight_on_input=apply_router_weight_on_input, ) elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8 - and _valid_cutlass_block_scaled_grouped_gemm(w1, w2)): - assert apply_router_weight_on_input is False + and _valid_cutlass_block_scaled_grouped_gemm( + w1, w2, inplace, activation, apply_router_weight_on_input, + expert_map)): return run_cutlass_block_scaled_fused_experts( a=hidden_states, w1=w1,