diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index d771a7a54cfc1..de588d512739d 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -522,16 +522,14 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, return out.to(dtype=out_dtype) -def _valid_cutlass_block_scaled_grouped_gemm(hidden_states: torch.Tensor, - w1: torch.Tensor, +def _valid_cutlass_block_scaled_grouped_gemm(w1: torch.Tensor, w2: torch.Tensor) -> bool: - def _valid_cutlass_block_scaled_grouped_gemm_shape(M: int, N: int, K: int): - return M >= 128 and N % 128 == 0 and K % 128 == 0 + def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int): + return N % 128 == 0 and K % 128 == 0 - m = hidden_states.size(0) _, K, N = w2.size() - if not _valid_cutlass_block_scaled_grouped_gemm_shape(m, N, K): + if not _valid_cutlass_block_scaled_grouped_gemm_shape(N, K): logger.debug( "CutlassBlockScaledGroupedGemm disabled: unalinged problem size.") return False diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index fbbccbb34d902..d0ff44a38a4aa 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1180,7 +1180,7 @@ def fused_experts( apply_router_weight_on_input=apply_router_weight_on_input, ) elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8 - and _valid_cutlass_block_scaled_grouped_gemm(hidden_states, w1, w2)): + and _valid_cutlass_block_scaled_grouped_gemm(w1, w2)): assert apply_router_weight_on_input is False return run_cutlass_block_scaled_fused_experts( a=hidden_states,