[feat] enable SM100 CUTLASS block scaled group gemm for smaller batch sizes (#20640)

Signed-off-by: Duncan Moss <djm.moss@gmail.com>
Duncan Moss 2025-07-08 20:03:35 -07:00 committed by GitHub
parent 34dad19e7b
commit 97abeb1daa
2 changed files with 5 additions and 7 deletions


@@ -522,16 +522,14 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
     return out.to(dtype=out_dtype)
 
 
-def _valid_cutlass_block_scaled_grouped_gemm(hidden_states: torch.Tensor,
-                                             w1: torch.Tensor,
+def _valid_cutlass_block_scaled_grouped_gemm(w1: torch.Tensor,
                                              w2: torch.Tensor) -> bool:
 
-    def _valid_cutlass_block_scaled_grouped_gemm_shape(M: int, N: int, K: int):
-        return M >= 128 and N % 128 == 0 and K % 128 == 0
+    def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int):
+        return N % 128 == 0 and K % 128 == 0
 
-    m = hidden_states.size(0)
     _, K, N = w2.size()
-    if not _valid_cutlass_block_scaled_grouped_gemm_shape(m, N, K):
+    if not _valid_cutlass_block_scaled_grouped_gemm_shape(N, K):
         logger.debug(
             "CutlassBlockScaledGroupedGemm disabled: unaligned problem size.")
         return False
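
Note: this hunk drops the M >= 128 requirement on the activation row count, so only the weight dimensions N and K still have to be 128-aligned. A minimal standalone sketch of the relaxed check (the w1 argument is kept for signature parity but, as in the diff, only w2 is inspected; the example shapes are hypothetical):

import torch

def _valid_cutlass_block_scaled_grouped_gemm(w1: torch.Tensor,
                                             w2: torch.Tensor) -> bool:
    # Only N and K must be multiples of 128; the batch size M no
    # longer gates selection of the CUTLASS block-scaled kernel.
    _, K, N = w2.size()
    return N % 128 == 0 and K % 128 == 0

# Hypothetical 128-aligned expert weights, laid out (num_experts, K, N):
w1 = torch.empty(8, 4096, 7168)
w2 = torch.empty(8, 7168, 4096)
assert _valid_cutlass_block_scaled_grouped_gemm(w1, w2)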


@@ -1180,7 +1180,7 @@ def fused_experts(
             apply_router_weight_on_input=apply_router_weight_on_input,
         )
     elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8
-          and _valid_cutlass_block_scaled_grouped_gemm(hidden_states, w1, w2)):
+          and _valid_cutlass_block_scaled_grouped_gemm(w1, w2)):
         assert apply_router_weight_on_input is False
         return run_cutlass_block_scaled_fused_experts(
             a=hidden_states,
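
With the batch-size gate removed, fused_experts can route small decode batches to the CUTLASS block-scaled path whenever the flag and fp8 conditions hold. A hedged illustration of the dispatch condition (flag values and shapes are hypothetical, with the shape check inlined from the sketch above):

import torch

# A small decode batch that the old M >= 128 gate would have rejected.
hidden_states = torch.empty(16, 7168)           # M = 16 rows
w2 = torch.empty(8, 7168, 4096)                 # (num_experts, K, N)

allow_cutlass_block_scaled_grouped_gemm = True  # assumed enabled
use_fp8_w8a8 = True                             # assumed fp8 w8a8 path

_, K, N = w2.size()
take_cutlass_path = (allow_cutlass_block_scaled_grouped_gemm
                     and use_fp8_w8a8
                     and N % 128 == 0 and K % 128 == 0)
print(take_cutlass_path)  # True, even though hidden_states has only 16 rows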