From 97abeb1daac6007526af435244d3f7047db272cd Mon Sep 17 00:00:00 2001
From: Duncan Moss
Date: Tue, 8 Jul 2025 20:03:35 -0700
Subject: [PATCH] [feat] enable SM100 CUTLASS block scaled group gemm for smaller batch sizes (#20640)

Signed-off-by: Duncan Moss
---
 vllm/model_executor/layers/fused_moe/cutlass_moe.py | 10 ++++------
 vllm/model_executor/layers/fused_moe/fused_moe.py   |  2 +-
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
index d771a7a54cfc1..de588d512739d 100644
--- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -522,16 +522,14 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
     return out.to(dtype=out_dtype)
 
 
-def _valid_cutlass_block_scaled_grouped_gemm(hidden_states: torch.Tensor,
-                                             w1: torch.Tensor,
+def _valid_cutlass_block_scaled_grouped_gemm(w1: torch.Tensor,
                                              w2: torch.Tensor) -> bool:
 
-    def _valid_cutlass_block_scaled_grouped_gemm_shape(M: int, N: int, K: int):
-        return M >= 128 and N % 128 == 0 and K % 128 == 0
+    def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int):
+        return N % 128 == 0 and K % 128 == 0
 
-    m = hidden_states.size(0)
     _, K, N = w2.size()
-    if not _valid_cutlass_block_scaled_grouped_gemm_shape(m, N, K):
+    if not _valid_cutlass_block_scaled_grouped_gemm_shape(N, K):
         logger.debug(
             "CutlassBlockScaledGroupedGemm disabled: unalinged problem size.")
         return False
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index fbbccbb34d902..d0ff44a38a4aa 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1180,7 +1180,7 @@ def fused_experts(
             apply_router_weight_on_input=apply_router_weight_on_input,
         )
     elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8
-          and _valid_cutlass_block_scaled_grouped_gemm(hidden_states, w1, w2)):
+          and _valid_cutlass_block_scaled_grouped_gemm(w1, w2)):
         assert apply_router_weight_on_input is False
         return run_cutlass_block_scaled_fused_experts(
             a=hidden_states,